diff --git a/src/P2/processing.py b/src/P2/processing.py index cb64fa2..989fa3f 100644 --- a/src/P2/processing.py +++ b/src/P2/processing.py @@ -5,7 +5,7 @@ from matplotlib.pyplot import * from pandas import DataFrame from seaborn import heatmap, set_style, set_theme, pairplot from sklearn.metrics import silhouette_score, calinski_harabasz_score -from sklearn.cluster import KMeans, Birch, AffinityPropagation, MeanShift, DBSCAN +from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN from preprocessing import parse_data, filter_dataframe @@ -15,8 +15,8 @@ def choose_model(model): return KMeans(random_state=42) elif model == "birch": return Birch() - elif model == "affinity": - return AffinityPropagation(random_state=42) + elif model == "spectral": + return SpectralClustering() elif model == "meanshift": return MeanShift() elif model == "dbscan": @@ -29,6 +29,7 @@ def predict_data(data, model, results, sample): start_time = time.time() prediction = model.fit_predict(data) execution_time = time.time() - start_time + cluster_number = len(set(prediction)) calinski = calinski_harabasz_score(X=data, labels=prediction) silhouette = silhouette_score( X=data, @@ -41,7 +42,7 @@ def predict_data(data, model, results, sample): df=results, model=model_name, prediction=prediction, - clusters=len(set(prediction)), + clusters=cluster_number, calinski=calinski, silhouette=silhouette, time=execution_time, @@ -79,17 +80,12 @@ def plot_scatter_plot(results): fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png") -def print_dataframe(df): - df.set_index("model") - print(df) - - def show_results(results): set_theme() set_style("white") plot_heatmap(results=results) plot_scatter_plot(results=results) - print_dataframe(df=results) + print(results) def create_result_dataframes(): @@ -103,8 +99,7 @@ def create_result_dataframes(): "time", ] ) - indexed_results = results.set_index("model") - return indexed_results, indexed_results + return results, results def populate_results(df, model, clusters, prediction, calinski, silhouette, time): @@ -153,7 +148,7 @@ def usage(): def main(): - models = ["kmeans", "birch", "affinity", "meanshift", "dbscan"] + models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"] if len(argv) != 4: usage() case, sample = argv[2], int(argv[3]) @@ -172,7 +167,7 @@ def main(): individual_result.append(model_results) ) complete_results.set_index("model") - print_dataframe(df=complete_results) + print(complete_results) if __name__ == "__main__":