Plot confusion matrix and simplify plot selection

This commit is contained in:
coolneng 2020-12-09 12:49:25 +01:00
parent bed9f1d0e9
commit 4af2c35c2e
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed files with 34 additions and 46 deletions

View File

@ -7,7 +7,7 @@ from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import scale from sklearn.preprocessing import scale
from sklearn.svm import LinearSVC from sklearn.svm import LinearSVC
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from seaborn import set_theme from seaborn import set_theme, set_style, heatmap, FacetGrid
from matplotlib.pyplot import * from matplotlib.pyplot import *
from pandas import DataFrame from pandas import DataFrame
@ -53,7 +53,9 @@ def predict_data(data, target, model, results):
return populated_results return populated_results
def plot_roc_auc_curve(model, results): def plot_roc_auc_curve(results):
fig = figure(figsize=(8, 6))
for model in results.index:
rounded_auc = round(results.loc[model]["auc"], 3) rounded_auc = round(results.loc[model]["auc"], 3)
plot( plot(
results.loc[model]["fpr"], results.loc[model]["fpr"],
@ -63,31 +65,29 @@ def plot_roc_auc_curve(model, results):
xticks(arange(0.0, 1.0, step=0.1)) xticks(arange(0.0, 1.0, step=0.1))
yticks(arange(0.0, 1.0, step=0.1)) yticks(arange(0.0, 1.0, step=0.1))
legend(loc="lower right") legend(loc="lower right")
fig_title = "ROC AUC curve"
def plot_confusion_matrix(model, results):
matrix = results.loc[model]["confusion_matrix"]
classes = ["Negative", "Positive"]
for item in matrix:
text(x=0.5, y=0.5, s=item)
xticks(ticks=arange(len(classes)), labels=classes)
yticks(ticks=arange(len(classes)), labels=classes)
def choose_plot_type(type, model, results):
if type == "roc":
plot_roc_auc_curve(model, results)
elif type == "confusion_matrix":
plot_confusion_matrix(model, results)
def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
fig = figure(figsize=(8, 6))
for model in results.index:
choose_plot_type(type, model, results)
xlabel(x_axis)
ylabel(y_axis)
title(fig_title) title(fig_title)
xlabel("False positive rate")
ylabel("True positive rate")
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
def plot_confusion_matrix(results):
set_style("white")
matrix = results.filter(items=["model", "confusion_matrix"])
fig, axes = subplots(nrows=1, ncols=5, figsize=(8, 6))
for i in range(len(axes)):
heatmap(
ax=axes[i],
data=matrix.iloc[i]["confusion_matrix"],
cmap="Blues",
square=True,
annot=True,
cbar=False,
)
axes[i].set_title(matrix.index[i])
fig_title = "Confusion Matrix"
suptitle(fig_title)
show() show()
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png") fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
@ -95,20 +95,8 @@ def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
# TODO Add cross_val_score # TODO Add cross_val_score
def plot_all_figures(results): def plot_all_figures(results):
set_theme() set_theme()
plot_individual_figure( plot_roc_auc_curve(results=results)
results, plot_confusion_matrix(results=results)
type="roc",
x_axis="False positive rate",
y_axis="True positive rate",
fig_title="ROC AUC curve",
)
plot_individual_figure(
results,
type="confusion_matrix",
x_axis="Predicted values",
y_axis="Real values",
fig_title="Confusion Matrix",
)
def create_result_dataframes(): def create_result_dataframes():