From a3798a781f57d72b8a4cc4c3254fad73e30c564a Mon Sep 17 00:00:00 2001 From: coolneng Date: Wed, 9 Dec 2020 22:56:45 +0100 Subject: [PATCH] Fix typos --- src/P2/preprocessing.py | 11 +++++++++++ src/P2/processing.py | 12 +++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/P2/preprocessing.py b/src/P2/preprocessing.py index 6dbf838..c780168 100644 --- a/src/P2/preprocessing.py +++ b/src/P2/preprocessing.py @@ -38,6 +38,17 @@ def filter_dataframe(df) -> DataFrame: return filtered_df +def choose_numerical_values(df): + cols = [ + "TOT_HERIDOS_LEVES", + "TOT_HERIDOS_GRAVES", + "TOT_VEHICULOS_IMPLICADOS", + "TOT_MUERTOS", + ] + filtered_df = df.filter(items=cols) + return filtered_df + + def normalize_numerical_values(df) -> DataFrame: cols = [ "TOT_HERIDOS_LEVES", diff --git a/src/P2/processing.py b/src/P2/processing.py index 02bd9bf..1b91a52 100644 --- a/src/P2/processing.py +++ b/src/P2/processing.py @@ -8,7 +8,7 @@ from seaborn import heatmap, set_style, set_theme, pairplot from sklearn.metrics import silhouette_score, calinski_harabasz_score from sklearn.cluster import KMeans, Birch, AffinityPropagation, MeanShift, DBSCAN -from preprocessing import parse_data +from preprocessing import parse_data, choose_numerical_values def choose_model( @@ -29,11 +29,12 @@ def choose_model( def predict_data(data, model, results, sample) -> DataFrame: model = choose_model(model) start_time = time.time() - prediction = model.fit_predict(data) + numerical_data = choose_numerical_values(df=data) + prediction = model.fit_predict(numerical_data) execution_time = time.time() - start_time - calinski = calinski_harabasz_score(X=data, labels=prediction) + calinski = calinski_harabasz_score(X=numerical_data, labels=prediction) silhouette = silhouette_score( - X=data, + X=numerical_data, labels=prediction, metric="euclidean", sample_size=sample, @@ -130,7 +131,8 @@ def populate_results( def rename_model(model) -> str: short_name = ["kmeans", "birch", "affinity", "meanshift", "dbscan"] models = [ - "KMean(random_state=42)", + "KMeans(random_state=42)", + "Birch()", "AffinityPropagation(random_state=42)", "MeanShift()", "DBSCAN()",