diff --git a/src/processing.py b/src/processing.py index 4d6a276..af17e1e 100644 --- a/src/processing.py +++ b/src/processing.py @@ -1,15 +1,14 @@ from numpy import mean, arange from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve -from sklearn.model_selection import cross_val_predict from sklearn.naive_bayes import GaussianNB from sklearn.neural_network import MLPClassifier from sklearn.neighbors import KNeighborsClassifier from sklearn.preprocessing import scale from sklearn.svm import LinearSVC from sklearn.tree import DecisionTreeClassifier -from seaborn import set_theme, set_style, heatmap, FacetGrid +from seaborn import set_theme, set_style, heatmap, countplot from matplotlib.pyplot import * -from pandas import DataFrame +from pandas import DataFrame, cut from sys import argv @@ -92,11 +91,28 @@ def plot_confusion_matrix(results): fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png") -# TODO Add cross_val_score -def plot_all_figures(results): +def plot_attributes_correlation(data, target): + transformed_data = transform_dataframe(data, target) + fig, axes = subplots(nrows=5, ncols=1, figsize=(8, 6)) + for i in range(len(axes)): + countplot( + ax=axes[i], + x=transformed_data.columns[i], + data=transformed_data, + hue="Severity", + ) + axes[i].set_title(transformed_data.columns[i]) + fig_title = "Attribute's correlation" + suptitle(fig_title) + show() + fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png") + + +def plot_all_figures(results, data, target): set_theme() - plot_roc_auc_curve(results=results) - plot_confusion_matrix(results=results) + # plot_roc_auc_curve(results=results) + # plot_confusion_matrix(results=results) + plot_attributes_correlation(data=data, target=target) def create_result_dataframes(): @@ -127,6 +143,13 @@ def rename_model(model): return mapping[model] +def transform_dataframe(data, target): + joined_df = data.join(target) + binned_df = joined_df.copy() + binned_df["Age"] = cut(x=joined_df["Age"], bins=[15, 30, 45, 60, 75]) + return binned_df + + def usage(): print("Usage: " + argv[0] + "") print("preprocessing actions:") @@ -149,7 +172,7 @@ def main(): individual_result.append(model_results) ) indexed_results = complete_results.set_index("model") - plot_all_figures(results=indexed_results) + plot_all_figures(results=indexed_results, data=data, target=target) if __name__ == "__main__":