Plot attribute correlation

This commit is contained in:
coolneng 2020-12-09 14:10:16 +01:00
parent 4af2c35c2e
commit ef043650b5
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed files with 31 additions and 8 deletions

View File

@ -1,15 +1,14 @@
from numpy import mean, arange from numpy import mean, arange
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
from sklearn.model_selection import cross_val_predict
from sklearn.naive_bayes import GaussianNB from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import scale from sklearn.preprocessing import scale
from sklearn.svm import LinearSVC from sklearn.svm import LinearSVC
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from seaborn import set_theme, set_style, heatmap, FacetGrid from seaborn import set_theme, set_style, heatmap, countplot
from matplotlib.pyplot import * from matplotlib.pyplot import *
from pandas import DataFrame from pandas import DataFrame, cut
from sys import argv from sys import argv
@ -92,11 +91,28 @@ def plot_confusion_matrix(results):
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png") fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
# TODO Add cross_val_score def plot_attributes_correlation(data, target):
def plot_all_figures(results): transformed_data = transform_dataframe(data, target)
fig, axes = subplots(nrows=5, ncols=1, figsize=(8, 6))
for i in range(len(axes)):
countplot(
ax=axes[i],
x=transformed_data.columns[i],
data=transformed_data,
hue="Severity",
)
axes[i].set_title(transformed_data.columns[i])
fig_title = "Attribute's correlation"
suptitle(fig_title)
show()
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
def plot_all_figures(results, data, target):
set_theme() set_theme()
plot_roc_auc_curve(results=results) # plot_roc_auc_curve(results=results)
plot_confusion_matrix(results=results) # plot_confusion_matrix(results=results)
plot_attributes_correlation(data=data, target=target)
def create_result_dataframes(): def create_result_dataframes():
@ -127,6 +143,13 @@ def rename_model(model):
return mapping[model] return mapping[model]
def transform_dataframe(data, target):
joined_df = data.join(target)
binned_df = joined_df.copy()
binned_df["Age"] = cut(x=joined_df["Age"], bins=[15, 30, 45, 60, 75])
return binned_df
def usage(): def usage():
print("Usage: " + argv[0] + "<preprocessing action>") print("Usage: " + argv[0] + "<preprocessing action>")
print("preprocessing actions:") print("preprocessing actions:")
@ -149,7 +172,7 @@ def main():
individual_result.append(model_results) individual_result.append(model_results)
) )
indexed_results = complete_results.set_index("model") indexed_results = complete_results.set_index("model")
plot_all_figures(results=indexed_results) plot_all_figures(results=indexed_results, data=data, target=target)
if __name__ == "__main__": if __name__ == "__main__":