Plot attribute correlation

This commit is contained in:
coolneng 2020-12-09 14:10:16 +01:00
parent 4af2c35c2e
commit ef043650b5
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed file with 31 additions and 8 deletions

View File

@ -1,15 +1,14 @@
from numpy import mean, arange
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
from sklearn.model_selection import cross_val_predict
from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import scale
from sklearn.svm import LinearSVC
from sklearn.tree import DecisionTreeClassifier
from seaborn import set_theme, set_style, heatmap, FacetGrid
from seaborn import set_theme, set_style, heatmap, countplot
from matplotlib.pyplot import *
from pandas import DataFrame
from pandas import DataFrame, cut
from sys import argv
@ -92,11 +91,28 @@ def plot_confusion_matrix(results):
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
# TODO Add cross_val_score
def plot_all_figures(results):
def plot_attributes_correlation(data, target):
    """Draw per-attribute count plots split by "Severity" and save them.

    The frame returned by ``transform_dataframe(data, target)`` is plotted
    on a 5x1 grid of subplots: subplot ``k`` shows a seaborn count plot of
    the frame's ``k``-th column, with bars coloured by the "Severity"
    column. The figure is displayed and then written to
    ``docs/assets/attribute's_correlation.png``.

    Args:
        data: Attribute frame, forwarded to ``transform_dataframe``.
        target: Target column(s), forwarded to ``transform_dataframe``.
    """
    frame = transform_dataframe(data, target)
    fig, axes = subplots(nrows=5, ncols=1, figsize=(8, 6))
    for position, axis in enumerate(axes):
        # Subplot k always shows the k-th column of the transformed frame.
        attribute = frame.columns[position]
        countplot(ax=axis, x=attribute, data=frame, hue="Severity")
        axis.set_title(attribute)
    fig_title = "Attribute's correlation"
    suptitle(fig_title)
    show()
    # The file name is the title, lower-cased, with spaces turned into "_".
    output_name = fig_title.replace(" ", "_").lower()
    fig.savefig(f"docs/assets/{output_name}.png")
def plot_all_figures(results, data, target):
    """Apply the seaborn theme and draw the project's figures.

    Args:
        results: Forwarded to the ROC/AUC curve and confusion-matrix plots.
        data: Forwarded to ``plot_attributes_correlation``.
        target: Forwarded to ``plot_attributes_correlation``.
    """
    set_theme()
    plot_roc_auc_curve(results=results)
    plot_confusion_matrix(results=results)
    # NOTE(review): the two commented-out calls below duplicate the active
    # calls above. This looks like old/new lines of a diff shown side by
    # side -- confirm whether these plots are meant to be enabled or not.
    # plot_roc_auc_curve(results=results)
    # plot_confusion_matrix(results=results)
    plot_attributes_correlation(data=data, target=target)
def create_result_dataframes():
@ -127,6 +143,13 @@ def rename_model(model):
return mapping[model]
def transform_dataframe(data, target, age_bins=(15, 30, 45, 60, 75)):
    """Join the attributes with the target and discretise the "Age" column.

    Args:
        data: DataFrame of attributes; it must contain an "Age" column.
        target: Object joined onto ``data`` by index via ``DataFrame.join``
            (presumably the "Severity" column -- confirm against caller).
        age_bins: Bin edges handed to ``pandas.cut``. The default keeps the
            edges that were previously hard-coded, so existing callers see
            the same result.

    Returns:
        A new DataFrame holding the columns of ``data`` followed by those of
        ``target``, with "Age" replaced by interval categories. Ages that
        fall outside ``age_bins`` become NaN (``pandas.cut`` behaviour).
        ``data`` itself is not modified.
    """
    # ``join`` already returns a fresh frame, so assigning into it cannot
    # touch the caller's ``data``; no extra defensive copy is needed.
    binned_df = data.join(target)
    binned_df["Age"] = cut(x=binned_df["Age"], bins=list(age_bins))
    return binned_df
def usage():
print("Usage: " + argv[0] + "<preprocessing action>")
print("preprocessing actions:")
@ -149,7 +172,7 @@ def main():
individual_result.append(model_results)
)
indexed_results = complete_results.set_index("model")
plot_all_figures(results=indexed_results)
plot_all_figures(results=indexed_results, data=data, target=target)
if __name__ == "__main__":