Plot confusion matrix and simplify plot selection

This commit is contained in:
coolneng 2020-12-09 12:49:25 +01:00
parent bed9f1d0e9
commit 4af2c35c2e
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed files with 34 additions and 46 deletions

View File

@ -7,7 +7,7 @@ from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import scale
from sklearn.svm import LinearSVC
from sklearn.tree import DecisionTreeClassifier
from seaborn import set_theme
from seaborn import set_theme, set_style, heatmap, FacetGrid
from matplotlib.pyplot import *
from pandas import DataFrame
@ -53,7 +53,9 @@ def predict_data(data, target, model, results):
return populated_results
def plot_roc_auc_curve(model, results):
def plot_roc_auc_curve(results):
fig = figure(figsize=(8, 6))
for model in results.index:
rounded_auc = round(results.loc[model]["auc"], 3)
plot(
results.loc[model]["fpr"],
@ -63,31 +65,29 @@ def plot_roc_auc_curve(model, results):
xticks(arange(0.0, 1.0, step=0.1))
yticks(arange(0.0, 1.0, step=0.1))
legend(loc="lower right")
def plot_confusion_matrix(model, results):
matrix = results.loc[model]["confusion_matrix"]
classes = ["Negative", "Positive"]
for item in matrix:
text(x=0.5, y=0.5, s=item)
xticks(ticks=arange(len(classes)), labels=classes)
yticks(ticks=arange(len(classes)), labels=classes)
def choose_plot_type(type, model, results):
if type == "roc":
plot_roc_auc_curve(model, results)
elif type == "confusion_matrix":
plot_confusion_matrix(model, results)
def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
fig = figure(figsize=(8, 6))
for model in results.index:
choose_plot_type(type, model, results)
xlabel(x_axis)
ylabel(y_axis)
fig_title = "ROC AUC curve"
title(fig_title)
xlabel("False positive rate")
ylabel("True positive rate")
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
def plot_confusion_matrix(results):
set_style("white")
matrix = results.filter(items=["model", "confusion_matrix"])
fig, axes = subplots(nrows=1, ncols=5, figsize=(8, 6))
for i in range(len(axes)):
heatmap(
ax=axes[i],
data=matrix.iloc[i]["confusion_matrix"],
cmap="Blues",
square=True,
annot=True,
cbar=False,
)
axes[i].set_title(matrix.index[i])
fig_title = "Confusion Matrix"
suptitle(fig_title)
show()
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
@ -95,20 +95,8 @@ def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
# TODO Add cross_val_score
def plot_all_figures(results):
set_theme()
plot_individual_figure(
results,
type="roc",
x_axis="False positive rate",
y_axis="True positive rate",
fig_title="ROC AUC curve",
)
plot_individual_figure(
results,
type="confusion_matrix",
x_axis="Predicted values",
y_axis="Real values",
fig_title="Confusion Matrix",
)
plot_roc_auc_curve(results=results)
plot_confusion_matrix(results=results)
def create_result_dataframes():