Plot attribute correlation
This commit is contained in:
parent
4af2c35c2e
commit
ef043650b5
|
@ -1,15 +1,14 @@
|
|||
from numpy import mean, arange
|
||||
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
|
||||
from sklearn.model_selection import cross_val_predict
|
||||
from sklearn.naive_bayes import GaussianNB
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from sklearn.preprocessing import scale
|
||||
from sklearn.svm import LinearSVC
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from seaborn import set_theme, set_style, heatmap, FacetGrid
|
||||
from seaborn import set_theme, set_style, heatmap, countplot
|
||||
from matplotlib.pyplot import *
|
||||
from pandas import DataFrame
|
||||
from pandas import DataFrame, cut
|
||||
|
||||
from sys import argv
|
||||
|
||||
|
@ -92,11 +91,28 @@ def plot_confusion_matrix(results):
|
|||
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
||||
|
||||
|
||||
# TODO Add cross_val_score
|
||||
def plot_all_figures(results):
|
||||
def plot_attributes_correlation(data, target):
|
||||
transformed_data = transform_dataframe(data, target)
|
||||
fig, axes = subplots(nrows=5, ncols=1, figsize=(8, 6))
|
||||
for i in range(len(axes)):
|
||||
countplot(
|
||||
ax=axes[i],
|
||||
x=transformed_data.columns[i],
|
||||
data=transformed_data,
|
||||
hue="Severity",
|
||||
)
|
||||
axes[i].set_title(transformed_data.columns[i])
|
||||
fig_title = "Attribute's correlation"
|
||||
suptitle(fig_title)
|
||||
show()
|
||||
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
||||
|
||||
|
||||
def plot_all_figures(results, data, target):
|
||||
set_theme()
|
||||
plot_roc_auc_curve(results=results)
|
||||
plot_confusion_matrix(results=results)
|
||||
# plot_roc_auc_curve(results=results)
|
||||
# plot_confusion_matrix(results=results)
|
||||
plot_attributes_correlation(data=data, target=target)
|
||||
|
||||
|
||||
def create_result_dataframes():
|
||||
|
@ -127,6 +143,13 @@ def rename_model(model):
|
|||
return mapping[model]
|
||||
|
||||
|
||||
def transform_dataframe(data, target):
|
||||
joined_df = data.join(target)
|
||||
binned_df = joined_df.copy()
|
||||
binned_df["Age"] = cut(x=joined_df["Age"], bins=[15, 30, 45, 60, 75])
|
||||
return binned_df
|
||||
|
||||
|
||||
def usage():
|
||||
print("Usage: " + argv[0] + "<preprocessing action>")
|
||||
print("preprocessing actions:")
|
||||
|
@ -149,7 +172,7 @@ def main():
|
|||
individual_result.append(model_results)
|
||||
)
|
||||
indexed_results = complete_results.set_index("model")
|
||||
plot_all_figures(results=indexed_results)
|
||||
plot_all_figures(results=indexed_results, data=data, target=target)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue