Replace the samples argument with cluster number

This commit is contained in:
coolneng 2020-12-13 18:05:22 +01:00
parent b4e90c1174
commit e63406c0a8
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed file with 22 additions and 18 deletions

View File

@ -3,18 +3,18 @@ from sys import argv
from matplotlib.pyplot import *
from pandas import DataFrame
from seaborn import heatmap, set_style, set_theme, pairplot
from seaborn import clustermap, set_style, set_theme, pairplot
from sklearn.metrics import silhouette_score, calinski_harabasz_score
from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
from preprocessing import parse_data, filter_dataframe
def choose_model(model):
def choose_model(model, cluster_number):
if model == "kmeans":
return KMeans(random_state=42)
return KMeans(n_clusters=cluster_number, random_state=42)
elif model == "birch":
return Birch()
return Birch(n_clusters=cluster_number)
elif model == "spectral":
return SpectralClustering()
elif model == "meanshift":
@ -23,9 +23,9 @@ def choose_model(model):
return DBSCAN()
def predict_data(data, model, results, sample):
def predict_data(data, model, cluster_number, results):
model_name = model
model = choose_model(model)
model = choose_model(model=model, cluster_number=cluster_number)
start_time = time.time()
prediction = model.fit_predict(data)
execution_time = time.time() - start_time
@ -35,7 +35,6 @@ def predict_data(data, model, results, sample):
X=data,
labels=prediction,
metric="euclidean",
sample_size=sample,
random_state=42,
)
populated_results = populate_results(
@ -52,10 +51,13 @@ def predict_data(data, model, results, sample):
def plot_heatmap(results):
fig = figure(figsize=(20, 10))
heatmap(
data=results,
cmap="Blues",
square=True,
results.reset_index()
matrix = results["prediction"]
print(matrix)
clustermap(
data=matrix,
cmap="mako",
metric="euclidean",
annot=True,
)
fig_title = "Heatmap"
@ -66,10 +68,10 @@ def plot_heatmap(results):
def plot_scatter_plot(results):
fig = figure(figsize=(20, 10))
original_data = results.drop("prediction")
matrix = results.filter(items=["input", "prediction"])
pairplot(
data=results,
vars=original_data,
vars=matrix,
hue="prediction",
palette="Paired",
diag_kind="hist",
@ -138,12 +140,14 @@ def construct_case(df, choice):
def usage():
print("Usage: " + argv[0] + "<preprocessing action> <case> <sample size>")
print("Usage: " + argv[0] + "<preprocessing action> <case> <number of clusters>")
print("preprocessing actions:")
print("fill: fills the na values with the mean")
print("drop: drops the na values")
print("cases: choice of case study")
print("sample size: size of the sample when computing the Silhouette Coefficient")
print(
"number of clusters: number of clusters for the algorithms that use a fixed number"
)
exit()
@ -151,7 +155,7 @@ def main():
models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
if len(argv) != 4:
usage()
case, sample = argv[2], int(argv[3])
case, cluster_number = argv[2], int(argv[3])
data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1]))
individual_result, complete_results = create_result_dataframes()
case_data = construct_case(df=data, choice=case)
@ -161,13 +165,13 @@ def main():
data=filtered_data,
model=model,
results=individual_result,
sample=sample,
cluster_number=cluster_number,
)
complete_results = complete_results.append(
individual_result.append(model_results)
)
complete_results.set_index("model")
print(complete_results)
show_results(results=complete_results)
if __name__ == "__main__":