Remove imputation of values from part 2
This commit is contained in:
parent
8b2ce6b5c9
commit
59895f4b8a
|
@ -89,7 +89,6 @@ def plot_confusion_matrix(results):
|
||||||
axes[i].set_title(matrix.index[i])
|
axes[i].set_title(matrix.index[i])
|
||||||
fig_title = "Confusion Matrix"
|
fig_title = "Confusion Matrix"
|
||||||
suptitle(fig_title)
|
suptitle(fig_title)
|
||||||
show()
|
|
||||||
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
||||||
|
|
||||||
|
|
||||||
|
@ -106,7 +105,6 @@ def plot_attributes_correlation(data, target):
|
||||||
axes[i].set_title(transformed_data.columns[i])
|
axes[i].set_title(transformed_data.columns[i])
|
||||||
fig_title = "Attribute's correlation"
|
fig_title = "Attribute's correlation"
|
||||||
suptitle(fig_title)
|
suptitle(fig_title)
|
||||||
show()
|
|
||||||
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,18 +8,6 @@ def replace_values(df):
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
||||||
def process_na(df, action):
|
|
||||||
if action == "drop":
|
|
||||||
return df.dropna()
|
|
||||||
elif action == "fill":
|
|
||||||
return replace_values(df)
|
|
||||||
else:
|
|
||||||
print("Unknown action selected. The choices are: ")
|
|
||||||
print("fill: fills the na values with the mean")
|
|
||||||
print("drop: drops the na values")
|
|
||||||
exit()
|
|
||||||
|
|
||||||
|
|
||||||
def filter_dataframe(df):
|
def filter_dataframe(df):
|
||||||
relevant_columns = [
|
relevant_columns = [
|
||||||
"TOT_HERIDOS_LEVES",
|
"TOT_HERIDOS_LEVES",
|
||||||
|
@ -39,8 +27,8 @@ def normalize_numerical_values(df):
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
||||||
def parse_data(source, action):
|
def parse_data(source):
|
||||||
df = read_csv(filepath_or_buffer=source, na_values="?")
|
df = read_csv(filepath_or_buffer=source, na_values="?")
|
||||||
processed_df = process_na(df=df, action=action)
|
processed_df = df.dropna()
|
||||||
normalized_df = normalize_numerical_values(df=processed_df)
|
normalized_df = normalize_numerical_values(df=processed_df)
|
||||||
return normalized_df
|
return normalized_df
|
||||||
|
|
|
@ -3,7 +3,6 @@ from sys import argv
|
||||||
|
|
||||||
from matplotlib.pyplot import *
|
from matplotlib.pyplot import *
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
from seaborn import clustermap, set_style, set_theme, pairplot
|
|
||||||
from sklearn.metrics import silhouette_score, calinski_harabasz_score
|
from sklearn.metrics import silhouette_score, calinski_harabasz_score
|
||||||
from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
|
from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
|
||||||
|
|
||||||
|
@ -49,47 +48,6 @@ def predict_data(data, model, cluster_number, results):
|
||||||
return populated_results
|
return populated_results
|
||||||
|
|
||||||
|
|
||||||
def plot_heatmap(results):
|
|
||||||
fig = figure(figsize=(20, 10))
|
|
||||||
results.reset_index()
|
|
||||||
matrix = results["prediction"]
|
|
||||||
print(matrix)
|
|
||||||
clustermap(
|
|
||||||
data=matrix,
|
|
||||||
cmap="mako",
|
|
||||||
metric="euclidean",
|
|
||||||
annot=True,
|
|
||||||
)
|
|
||||||
fig_title = "Heatmap"
|
|
||||||
title(fig_title)
|
|
||||||
show()
|
|
||||||
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
|
||||||
|
|
||||||
|
|
||||||
def plot_scatter_plot(results):
|
|
||||||
fig = figure(figsize=(20, 10))
|
|
||||||
matrix = results.filter(items=["input", "prediction"])
|
|
||||||
pairplot(
|
|
||||||
data=results,
|
|
||||||
vars=matrix,
|
|
||||||
hue="prediction",
|
|
||||||
palette="Paired",
|
|
||||||
diag_kind="hist",
|
|
||||||
)
|
|
||||||
fig_title = "Scatter plot"
|
|
||||||
title(fig_title)
|
|
||||||
show()
|
|
||||||
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
|
||||||
|
|
||||||
|
|
||||||
def show_results(results):
|
|
||||||
set_theme()
|
|
||||||
set_style("white")
|
|
||||||
plot_heatmap(results=results)
|
|
||||||
plot_scatter_plot(results=results)
|
|
||||||
print(results)
|
|
||||||
|
|
||||||
|
|
||||||
def create_result_dataframes():
|
def create_result_dataframes():
|
||||||
results = DataFrame(
|
results = DataFrame(
|
||||||
columns=[
|
columns=[
|
||||||
|
@ -153,10 +111,10 @@ def usage():
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
|
models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
|
||||||
if len(argv) != 4:
|
if len(argv) != 3:
|
||||||
usage()
|
usage()
|
||||||
case, cluster_number = argv[2], int(argv[3])
|
case, cluster_number = argv[1], int(argv[2])
|
||||||
data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1]))
|
data = parse_data(source="data/accidentes_2013.csv")
|
||||||
individual_result, complete_results = create_result_dataframes()
|
individual_result, complete_results = create_result_dataframes()
|
||||||
case_data = construct_case(df=data, choice=case)
|
case_data = construct_case(df=data, choice=case)
|
||||||
filtered_data = filter_dataframe(df=case_data)
|
filtered_data = filter_dataframe(df=case_data)
|
||||||
|
@ -171,7 +129,7 @@ def main():
|
||||||
individual_result.append(model_results)
|
individual_result.append(model_results)
|
||||||
)
|
)
|
||||||
complete_results.set_index("model")
|
complete_results.set_index("model")
|
||||||
show_results(results=complete_results)
|
print(complete_results)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue