Plot confusion matrix and simplify plot selection
This commit is contained in:
parent
bed9f1d0e9
commit
4af2c35c2e
|
@ -7,7 +7,7 @@ from sklearn.neighbors import KNeighborsClassifier
|
||||||
from sklearn.preprocessing import scale
|
from sklearn.preprocessing import scale
|
||||||
from sklearn.svm import LinearSVC
|
from sklearn.svm import LinearSVC
|
||||||
from sklearn.tree import DecisionTreeClassifier
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
from seaborn import set_theme
|
from seaborn import set_theme, set_style, heatmap, FacetGrid
|
||||||
from matplotlib.pyplot import *
|
from matplotlib.pyplot import *
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
|
|
||||||
|
@ -53,7 +53,9 @@ def predict_data(data, target, model, results):
|
||||||
return populated_results
|
return populated_results
|
||||||
|
|
||||||
|
|
||||||
def plot_roc_auc_curve(model, results):
|
def plot_roc_auc_curve(results):
|
||||||
|
fig = figure(figsize=(8, 6))
|
||||||
|
for model in results.index:
|
||||||
rounded_auc = round(results.loc[model]["auc"], 3)
|
rounded_auc = round(results.loc[model]["auc"], 3)
|
||||||
plot(
|
plot(
|
||||||
results.loc[model]["fpr"],
|
results.loc[model]["fpr"],
|
||||||
|
@ -63,31 +65,29 @@ def plot_roc_auc_curve(model, results):
|
||||||
xticks(arange(0.0, 1.0, step=0.1))
|
xticks(arange(0.0, 1.0, step=0.1))
|
||||||
yticks(arange(0.0, 1.0, step=0.1))
|
yticks(arange(0.0, 1.0, step=0.1))
|
||||||
legend(loc="lower right")
|
legend(loc="lower right")
|
||||||
|
fig_title = "ROC AUC curve"
|
||||||
|
|
||||||
def plot_confusion_matrix(model, results):
|
|
||||||
matrix = results.loc[model]["confusion_matrix"]
|
|
||||||
classes = ["Negative", "Positive"]
|
|
||||||
for item in matrix:
|
|
||||||
text(x=0.5, y=0.5, s=item)
|
|
||||||
xticks(ticks=arange(len(classes)), labels=classes)
|
|
||||||
yticks(ticks=arange(len(classes)), labels=classes)
|
|
||||||
|
|
||||||
|
|
||||||
def choose_plot_type(type, model, results):
|
|
||||||
if type == "roc":
|
|
||||||
plot_roc_auc_curve(model, results)
|
|
||||||
elif type == "confusion_matrix":
|
|
||||||
plot_confusion_matrix(model, results)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
|
|
||||||
fig = figure(figsize=(8, 6))
|
|
||||||
for model in results.index:
|
|
||||||
choose_plot_type(type, model, results)
|
|
||||||
xlabel(x_axis)
|
|
||||||
ylabel(y_axis)
|
|
||||||
title(fig_title)
|
title(fig_title)
|
||||||
|
xlabel("False positive rate")
|
||||||
|
ylabel("True positive rate")
|
||||||
|
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
||||||
|
|
||||||
|
|
||||||
|
def plot_confusion_matrix(results):
|
||||||
|
set_style("white")
|
||||||
|
matrix = results.filter(items=["model", "confusion_matrix"])
|
||||||
|
fig, axes = subplots(nrows=1, ncols=5, figsize=(8, 6))
|
||||||
|
for i in range(len(axes)):
|
||||||
|
heatmap(
|
||||||
|
ax=axes[i],
|
||||||
|
data=matrix.iloc[i]["confusion_matrix"],
|
||||||
|
cmap="Blues",
|
||||||
|
square=True,
|
||||||
|
annot=True,
|
||||||
|
cbar=False,
|
||||||
|
)
|
||||||
|
axes[i].set_title(matrix.index[i])
|
||||||
|
fig_title = "Confusion Matrix"
|
||||||
|
suptitle(fig_title)
|
||||||
show()
|
show()
|
||||||
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
||||||
|
|
||||||
|
@ -95,20 +95,8 @@ def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
|
||||||
# TODO Add cross_val_score
|
# TODO Add cross_val_score
|
||||||
def plot_all_figures(results):
|
def plot_all_figures(results):
|
||||||
set_theme()
|
set_theme()
|
||||||
plot_individual_figure(
|
plot_roc_auc_curve(results=results)
|
||||||
results,
|
plot_confusion_matrix(results=results)
|
||||||
type="roc",
|
|
||||||
x_axis="False positive rate",
|
|
||||||
y_axis="True positive rate",
|
|
||||||
fig_title="ROC AUC curve",
|
|
||||||
)
|
|
||||||
plot_individual_figure(
|
|
||||||
results,
|
|
||||||
type="confusion_matrix",
|
|
||||||
x_axis="Predicted values",
|
|
||||||
y_axis="Real values",
|
|
||||||
fig_title="Confusion Matrix",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_result_dataframes():
|
def create_result_dataframes():
|
||||||
|
|
Loading…
Reference in New Issue