Plot confusion matrix and simplify plot selection
This commit is contained in:
parent
bed9f1d0e9
commit
4af2c35c2e
|
@ -7,7 +7,7 @@ from sklearn.neighbors import KNeighborsClassifier
|
|||
from sklearn.preprocessing import scale
|
||||
from sklearn.svm import LinearSVC
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from seaborn import set_theme
|
||||
from seaborn import set_theme, set_style, heatmap, FacetGrid
|
||||
from matplotlib.pyplot import *
|
||||
from pandas import DataFrame
|
||||
|
||||
|
@ -53,7 +53,9 @@ def predict_data(data, target, model, results):
|
|||
return populated_results
|
||||
|
||||
|
||||
def plot_roc_auc_curve(model, results):
|
||||
def plot_roc_auc_curve(results):
|
||||
fig = figure(figsize=(8, 6))
|
||||
for model in results.index:
|
||||
rounded_auc = round(results.loc[model]["auc"], 3)
|
||||
plot(
|
||||
results.loc[model]["fpr"],
|
||||
|
@ -63,31 +65,29 @@ def plot_roc_auc_curve(model, results):
|
|||
xticks(arange(0.0, 1.0, step=0.1))
|
||||
yticks(arange(0.0, 1.0, step=0.1))
|
||||
legend(loc="lower right")
|
||||
|
||||
|
||||
def plot_confusion_matrix(model, results):
|
||||
matrix = results.loc[model]["confusion_matrix"]
|
||||
classes = ["Negative", "Positive"]
|
||||
for item in matrix:
|
||||
text(x=0.5, y=0.5, s=item)
|
||||
xticks(ticks=arange(len(classes)), labels=classes)
|
||||
yticks(ticks=arange(len(classes)), labels=classes)
|
||||
|
||||
|
||||
def choose_plot_type(type, model, results):
|
||||
if type == "roc":
|
||||
plot_roc_auc_curve(model, results)
|
||||
elif type == "confusion_matrix":
|
||||
plot_confusion_matrix(model, results)
|
||||
|
||||
|
||||
def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
|
||||
fig = figure(figsize=(8, 6))
|
||||
for model in results.index:
|
||||
choose_plot_type(type, model, results)
|
||||
xlabel(x_axis)
|
||||
ylabel(y_axis)
|
||||
fig_title = "ROC AUC curve"
|
||||
title(fig_title)
|
||||
xlabel("False positive rate")
|
||||
ylabel("True positive rate")
|
||||
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
||||
|
||||
|
||||
def plot_confusion_matrix(results):
|
||||
set_style("white")
|
||||
matrix = results.filter(items=["model", "confusion_matrix"])
|
||||
fig, axes = subplots(nrows=1, ncols=5, figsize=(8, 6))
|
||||
for i in range(len(axes)):
|
||||
heatmap(
|
||||
ax=axes[i],
|
||||
data=matrix.iloc[i]["confusion_matrix"],
|
||||
cmap="Blues",
|
||||
square=True,
|
||||
annot=True,
|
||||
cbar=False,
|
||||
)
|
||||
axes[i].set_title(matrix.index[i])
|
||||
fig_title = "Confusion Matrix"
|
||||
suptitle(fig_title)
|
||||
show()
|
||||
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
||||
|
||||
|
@ -95,20 +95,8 @@ def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
|
|||
# TODO Add cross_val_score
|
||||
def plot_all_figures(results):
|
||||
set_theme()
|
||||
plot_individual_figure(
|
||||
results,
|
||||
type="roc",
|
||||
x_axis="False positive rate",
|
||||
y_axis="True positive rate",
|
||||
fig_title="ROC AUC curve",
|
||||
)
|
||||
plot_individual_figure(
|
||||
results,
|
||||
type="confusion_matrix",
|
||||
x_axis="Predicted values",
|
||||
y_axis="Real values",
|
||||
fig_title="Confusion Matrix",
|
||||
)
|
||||
plot_roc_auc_curve(results=results)
|
||||
plot_confusion_matrix(results=results)
|
||||
|
||||
|
||||
def create_result_dataframes():
|
||||
|
|
Loading…
Reference in New Issue