Replace AffinityPropagation with SpectralClustering
This commit is contained in:
parent
7e356fed37
commit
b4e90c1174
|
@ -5,7 +5,7 @@ from matplotlib.pyplot import *
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
from seaborn import heatmap, set_style, set_theme, pairplot
|
from seaborn import heatmap, set_style, set_theme, pairplot
|
||||||
from sklearn.metrics import silhouette_score, calinski_harabasz_score
|
from sklearn.metrics import silhouette_score, calinski_harabasz_score
|
||||||
from sklearn.cluster import KMeans, Birch, AffinityPropagation, MeanShift, DBSCAN
|
from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
|
||||||
|
|
||||||
from preprocessing import parse_data, filter_dataframe
|
from preprocessing import parse_data, filter_dataframe
|
||||||
|
|
||||||
|
@ -15,8 +15,8 @@ def choose_model(model):
|
||||||
return KMeans(random_state=42)
|
return KMeans(random_state=42)
|
||||||
elif model == "birch":
|
elif model == "birch":
|
||||||
return Birch()
|
return Birch()
|
||||||
elif model == "affinity":
|
elif model == "spectral":
|
||||||
return AffinityPropagation(random_state=42)
|
return SpectralClustering()
|
||||||
elif model == "meanshift":
|
elif model == "meanshift":
|
||||||
return MeanShift()
|
return MeanShift()
|
||||||
elif model == "dbscan":
|
elif model == "dbscan":
|
||||||
|
@ -29,6 +29,7 @@ def predict_data(data, model, results, sample):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
prediction = model.fit_predict(data)
|
prediction = model.fit_predict(data)
|
||||||
execution_time = time.time() - start_time
|
execution_time = time.time() - start_time
|
||||||
|
cluster_number = len(set(prediction))
|
||||||
calinski = calinski_harabasz_score(X=data, labels=prediction)
|
calinski = calinski_harabasz_score(X=data, labels=prediction)
|
||||||
silhouette = silhouette_score(
|
silhouette = silhouette_score(
|
||||||
X=data,
|
X=data,
|
||||||
|
@ -41,7 +42,7 @@ def predict_data(data, model, results, sample):
|
||||||
df=results,
|
df=results,
|
||||||
model=model_name,
|
model=model_name,
|
||||||
prediction=prediction,
|
prediction=prediction,
|
||||||
clusters=len(set(prediction)),
|
clusters=cluster_number,
|
||||||
calinski=calinski,
|
calinski=calinski,
|
||||||
silhouette=silhouette,
|
silhouette=silhouette,
|
||||||
time=execution_time,
|
time=execution_time,
|
||||||
|
@ -79,17 +80,12 @@ def plot_scatter_plot(results):
|
||||||
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
|
||||||
|
|
||||||
|
|
||||||
def print_dataframe(df):
|
|
||||||
df.set_index("model")
|
|
||||||
print(df)
|
|
||||||
|
|
||||||
|
|
||||||
def show_results(results):
|
def show_results(results):
|
||||||
set_theme()
|
set_theme()
|
||||||
set_style("white")
|
set_style("white")
|
||||||
plot_heatmap(results=results)
|
plot_heatmap(results=results)
|
||||||
plot_scatter_plot(results=results)
|
plot_scatter_plot(results=results)
|
||||||
print_dataframe(df=results)
|
print(results)
|
||||||
|
|
||||||
|
|
||||||
def create_result_dataframes():
|
def create_result_dataframes():
|
||||||
|
@ -103,8 +99,7 @@ def create_result_dataframes():
|
||||||
"time",
|
"time",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
indexed_results = results.set_index("model")
|
return results, results
|
||||||
return indexed_results, indexed_results
|
|
||||||
|
|
||||||
|
|
||||||
def populate_results(df, model, clusters, prediction, calinski, silhouette, time):
|
def populate_results(df, model, clusters, prediction, calinski, silhouette, time):
|
||||||
|
@ -153,7 +148,7 @@ def usage():
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
models = ["kmeans", "birch", "affinity", "meanshift", "dbscan"]
|
models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
|
||||||
if len(argv) != 4:
|
if len(argv) != 4:
|
||||||
usage()
|
usage()
|
||||||
case, sample = argv[2], int(argv[3])
|
case, sample = argv[2], int(argv[3])
|
||||||
|
@ -172,7 +167,7 @@ def main():
|
||||||
individual_result.append(model_results)
|
individual_result.append(model_results)
|
||||||
)
|
)
|
||||||
complete_results.set_index("model")
|
complete_results.set_index("model")
|
||||||
print_dataframe(df=complete_results)
|
print(complete_results)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue