Replace the samples argument with cluster number

This commit is contained in:
coolneng 2020-12-13 18:05:22 +01:00
parent b4e90c1174
commit e63406c0a8
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed file with 22 additions and 18 deletions

View File

@@ -3,18 +3,18 @@ from sys import argv
from matplotlib.pyplot import * from matplotlib.pyplot import *
from pandas import DataFrame from pandas import DataFrame
from seaborn import heatmap, set_style, set_theme, pairplot from seaborn import clustermap, set_style, set_theme, pairplot
from sklearn.metrics import silhouette_score, calinski_harabasz_score from sklearn.metrics import silhouette_score, calinski_harabasz_score
from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
from preprocessing import parse_data, filter_dataframe from preprocessing import parse_data, filter_dataframe
def choose_model(model): def choose_model(model, cluster_number):
if model == "kmeans": if model == "kmeans":
return KMeans(random_state=42) return KMeans(n_clusters=cluster_number, random_state=42)
elif model == "birch": elif model == "birch":
return Birch() return Birch(n_clusters=cluster_number)
elif model == "spectral": elif model == "spectral":
return SpectralClustering() return SpectralClustering()
elif model == "meanshift": elif model == "meanshift":
@@ -23,9 +23,9 @@ def choose_model(model):
return DBSCAN() return DBSCAN()
def predict_data(data, model, results, sample): def predict_data(data, model, cluster_number, results):
model_name = model model_name = model
model = choose_model(model) model = choose_model(model=model, cluster_number=cluster_number)
start_time = time.time() start_time = time.time()
prediction = model.fit_predict(data) prediction = model.fit_predict(data)
execution_time = time.time() - start_time execution_time = time.time() - start_time
@@ -35,7 +35,6 @@ def predict_data(data, model, results, sample):
X=data, X=data,
labels=prediction, labels=prediction,
metric="euclidean", metric="euclidean",
sample_size=sample,
random_state=42, random_state=42,
) )
populated_results = populate_results( populated_results = populate_results(
@@ -52,10 +51,13 @@ def predict_data(data, model, results, sample):
def plot_heatmap(results): def plot_heatmap(results):
fig = figure(figsize=(20, 10)) fig = figure(figsize=(20, 10))
heatmap( results.reset_index()
data=results, matrix = results["prediction"]
cmap="Blues", print(matrix)
square=True, clustermap(
data=matrix,
cmap="mako",
metric="euclidean",
annot=True, annot=True,
) )
fig_title = "Heatmap" fig_title = "Heatmap"
@@ -66,10 +68,10 @@ def plot_heatmap(results):
def plot_scatter_plot(results): def plot_scatter_plot(results):
fig = figure(figsize=(20, 10)) fig = figure(figsize=(20, 10))
original_data = results.drop("prediction") matrix = results.filter(items=["input", "prediction"])
pairplot( pairplot(
data=results, data=results,
vars=original_data, vars=matrix,
hue="prediction", hue="prediction",
palette="Paired", palette="Paired",
diag_kind="hist", diag_kind="hist",
@@ -138,12 +140,14 @@ def construct_case(df, choice):
def usage(): def usage():
print("Usage: " + argv[0] + "<preprocessing action> <case> <sample size>") print("Usage: " + argv[0] + "<preprocessing action> <case> <number of clusters>")
print("preprocessing actions:") print("preprocessing actions:")
print("fill: fills the na values with the mean") print("fill: fills the na values with the mean")
print("drop: drops the na values") print("drop: drops the na values")
print("cases: choice of case study") print("cases: choice of case study")
print("sample size: size of the sample when computing the Silhouette Coefficient") print(
"number of clusters: number of clusters for the algorithms that use a fixed number"
)
exit() exit()
@@ -151,7 +155,7 @@ def main():
models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"] models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
if len(argv) != 4: if len(argv) != 4:
usage() usage()
case, sample = argv[2], int(argv[3]) case, cluster_number = argv[2], int(argv[3])
data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1])) data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1]))
individual_result, complete_results = create_result_dataframes() individual_result, complete_results = create_result_dataframes()
case_data = construct_case(df=data, choice=case) case_data = construct_case(df=data, choice=case)
@@ -161,13 +165,13 @@ def main():
data=filtered_data, data=filtered_data,
model=model, model=model,
results=individual_result, results=individual_result,
sample=sample, cluster_number=cluster_number,
) )
complete_results = complete_results.append( complete_results = complete_results.append(
individual_result.append(model_results) individual_result.append(model_results)
) )
complete_results.set_index("model") complete_results.set_index("model")
print(complete_results) show_results(results=complete_results)
if __name__ == "__main__": if __name__ == "__main__":