From 4af2c35c2e70be177a52455c1d2ab33d4d148b01 Mon Sep 17 00:00:00 2001 From: coolneng Date: Wed, 9 Dec 2020 12:49:25 +0100 Subject: [PATCH] Plot confusion matrix and simplify plot selection --- src/processing.py | 80 ++++++++++++++++++++--------------------------- 1 file changed, 34 insertions(+), 46 deletions(-) diff --git a/src/processing.py b/src/processing.py index f648096..4d6a276 100644 --- a/src/processing.py +++ b/src/processing.py @@ -7,7 +7,7 @@ from sklearn.neighbors import KNeighborsClassifier from sklearn.preprocessing import scale from sklearn.svm import LinearSVC from sklearn.tree import DecisionTreeClassifier -from seaborn import set_theme +from seaborn import set_theme, set_style, heatmap, FacetGrid from matplotlib.pyplot import * from pandas import DataFrame @@ -53,41 +53,41 @@ def predict_data(data, target, model, results): return populated_results -def plot_roc_auc_curve(model, results): - rounded_auc = round(results.loc[model]["auc"], 3) - plot( - results.loc[model]["fpr"], - results.loc[model]["tpr"], - label=f"{model} , AUC={rounded_auc}", - ) +def plot_roc_auc_curve(results): + fig = figure(figsize=(8, 6)) + for model in results.index: + rounded_auc = round(results.loc[model]["auc"], 3) + plot( + results.loc[model]["fpr"], + results.loc[model]["tpr"], + label=f"{model} , AUC={rounded_auc}", + ) xticks(arange(0.0, 1.0, step=0.1)) yticks(arange(0.0, 1.0, step=0.1)) legend(loc="lower right") - - -def plot_confusion_matrix(model, results): - matrix = results.loc[model]["confusion_matrix"] - classes = ["Negative", "Positive"] - for item in matrix: - text(x=0.5, y=0.5, s=item) - xticks(ticks=arange(len(classes)), labels=classes) - yticks(ticks=arange(len(classes)), labels=classes) - - -def choose_plot_type(type, model, results): - if type == "roc": - plot_roc_auc_curve(model, results) - elif type == "confusion_matrix": - plot_confusion_matrix(model, results) - - -def plot_individual_figure(results, type, x_axis, y_axis, fig_title): - fig = figure(figsize=(8, 6)) - for model in results.index: - choose_plot_type(type, model, results) - xlabel(x_axis) - ylabel(y_axis) + fig_title = "ROC AUC curve" title(fig_title) + xlabel("False positive rate") + ylabel("True positive rate") + fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png") + + +def plot_confusion_matrix(results): + set_style("white") + matrix = results.filter(items=["model", "confusion_matrix"]) + fig, axes = subplots(nrows=1, ncols=5, figsize=(8, 6)) + for i in range(len(axes)): + heatmap( + ax=axes[i], + data=matrix.iloc[i]["confusion_matrix"], + cmap="Blues", + square=True, + annot=True, + cbar=False, + ) + axes[i].set_title(matrix.index[i]) + fig_title = "Confusion Matrix" + suptitle(fig_title) show() fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png") @@ -95,20 +95,8 @@ def plot_individual_figure(results, type, x_axis, y_axis, fig_title): # TODO Add cross_val_score def plot_all_figures(results): set_theme() - plot_individual_figure( - results, - type="roc", - x_axis="False positive rate", - y_axis="True positive rate", - fig_title="ROC AUC curve", - ) - plot_individual_figure( - results, - type="confusion_matrix", - x_axis="Predicted values", - y_axis="Real values", - fig_title="Confusion Matrix", - ) + plot_roc_auc_curve(results=results) + plot_confusion_matrix(results=results) def create_result_dataframes():