diff --git a/src/P2/processing.py b/src/P2/processing.py index 989fa3f..40e039b 100644 --- a/src/P2/processing.py +++ b/src/P2/processing.py @@ -3,18 +3,18 @@ from sys import argv from matplotlib.pyplot import * from pandas import DataFrame -from seaborn import heatmap, set_style, set_theme, pairplot +from seaborn import clustermap, set_style, set_theme, pairplot from sklearn.metrics import silhouette_score, calinski_harabasz_score from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN from preprocessing import parse_data, filter_dataframe -def choose_model(model): +def choose_model(model, cluster_number): if model == "kmeans": - return KMeans(random_state=42) + return KMeans(n_clusters=cluster_number, random_state=42) elif model == "birch": - return Birch() + return Birch(n_clusters=cluster_number) elif model == "spectral": return SpectralClustering() elif model == "meanshift": @@ -23,9 +23,9 @@ def choose_model(model): return DBSCAN() -def predict_data(data, model, results, sample): +def predict_data(data, model, cluster_number, results): model_name = model - model = choose_model(model) + model = choose_model(model=model, cluster_number=cluster_number) start_time = time.time() prediction = model.fit_predict(data) execution_time = time.time() - start_time @@ -35,7 +35,6 @@ def predict_data(data, model, results, sample): X=data, labels=prediction, metric="euclidean", - sample_size=sample, random_state=42, ) populated_results = populate_results( @@ -52,10 +51,13 @@ def predict_data(data, model, results, sample): def plot_heatmap(results): fig = figure(figsize=(20, 10)) - heatmap( - data=results, - cmap="Blues", - square=True, + results.reset_index() + matrix = results["prediction"] + print(matrix) + clustermap( + data=matrix, + cmap="mako", + metric="euclidean", annot=True, ) fig_title = "Heatmap" @@ -66,10 +68,10 @@ def plot_heatmap(results): def plot_scatter_plot(results): fig = figure(figsize=(20, 10)) - original_data = results.drop("prediction") + matrix = results.filter(items=["input", "prediction"]) pairplot( data=results, - vars=original_data, + vars=matrix, hue="prediction", palette="Paired", diag_kind="hist", @@ -138,12 +140,14 @@ def construct_case(df, choice): def usage(): - print("Usage: " + argv[0] + " ") + print("Usage: " + argv[0] + " ") print("preprocessing actions:") print("fill: fills the na values with the mean") print("drop: drops the na values") print("cases: choice of case study") - print("sample size: size of the sample when computing the Silhouette Coefficient") + print( + "number of clusters: number of clusters for the algorithms that use a fixed number" + ) exit() @@ -151,7 +155,7 @@ def main(): models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"] if len(argv) != 4: usage() - case, sample = argv[2], int(argv[3]) + case, cluster_number = argv[2], int(argv[3]) data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1])) individual_result, complete_results = create_result_dataframes() case_data = construct_case(df=data, choice=case) @@ -161,13 +165,13 @@ def main(): data=filtered_data, model=model, results=individual_result, - sample=sample, + cluster_number=cluster_number, ) complete_results = complete_results.append( individual_result.append(model_results) ) complete_results.set_index("model") - print(complete_results) + show_results(results=complete_results) if __name__ == "__main__":