Replace AffinityPropagation with SpectralClustering

This commit is contained in:
coolneng 2020-12-11 18:58:50 +01:00
parent 7e356fed37
commit b4e90c1174
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed file with 9 additions and 14 deletions

View File

@ -5,7 +5,7 @@ from matplotlib.pyplot import *
from pandas import DataFrame from pandas import DataFrame
from seaborn import heatmap, set_style, set_theme, pairplot from seaborn import heatmap, set_style, set_theme, pairplot
from sklearn.metrics import silhouette_score, calinski_harabasz_score from sklearn.metrics import silhouette_score, calinski_harabasz_score
from sklearn.cluster import KMeans, Birch, AffinityPropagation, MeanShift, DBSCAN from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
from preprocessing import parse_data, filter_dataframe from preprocessing import parse_data, filter_dataframe
@ -15,8 +15,8 @@ def choose_model(model):
return KMeans(random_state=42) return KMeans(random_state=42)
elif model == "birch": elif model == "birch":
return Birch() return Birch()
elif model == "affinity": elif model == "spectral":
return AffinityPropagation(random_state=42) return SpectralClustering()
elif model == "meanshift": elif model == "meanshift":
return MeanShift() return MeanShift()
elif model == "dbscan": elif model == "dbscan":
@ -29,6 +29,7 @@ def predict_data(data, model, results, sample):
start_time = time.time() start_time = time.time()
prediction = model.fit_predict(data) prediction = model.fit_predict(data)
execution_time = time.time() - start_time execution_time = time.time() - start_time
cluster_number = len(set(prediction))
calinski = calinski_harabasz_score(X=data, labels=prediction) calinski = calinski_harabasz_score(X=data, labels=prediction)
silhouette = silhouette_score( silhouette = silhouette_score(
X=data, X=data,
@ -41,7 +42,7 @@ def predict_data(data, model, results, sample):
df=results, df=results,
model=model_name, model=model_name,
prediction=prediction, prediction=prediction,
clusters=len(set(prediction)), clusters=cluster_number,
calinski=calinski, calinski=calinski,
silhouette=silhouette, silhouette=silhouette,
time=execution_time, time=execution_time,
@ -79,17 +80,12 @@ def plot_scatter_plot(results):
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png") fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
def print_dataframe(df):
df.set_index("model")
print(df)
def show_results(results): def show_results(results):
set_theme() set_theme()
set_style("white") set_style("white")
plot_heatmap(results=results) plot_heatmap(results=results)
plot_scatter_plot(results=results) plot_scatter_plot(results=results)
print_dataframe(df=results) print(results)
def create_result_dataframes(): def create_result_dataframes():
@ -103,8 +99,7 @@ def create_result_dataframes():
"time", "time",
] ]
) )
indexed_results = results.set_index("model") return results, results
return indexed_results, indexed_results
def populate_results(df, model, clusters, prediction, calinski, silhouette, time): def populate_results(df, model, clusters, prediction, calinski, silhouette, time):
@ -153,7 +148,7 @@ def usage():
def main(): def main():
models = ["kmeans", "birch", "affinity", "meanshift", "dbscan"] models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
if len(argv) != 4: if len(argv) != 4:
usage() usage()
case, sample = argv[2], int(argv[3]) case, sample = argv[2], int(argv[3])
@ -172,7 +167,7 @@ def main():
individual_result.append(model_results) individual_result.append(model_results)
) )
complete_results.set_index("model") complete_results.set_index("model")
print_dataframe(df=complete_results) print(complete_results)
if __name__ == "__main__": if __name__ == "__main__":