Replace the samples argument with cluster number
This commit is contained in:
parent
b4e90c1174
commit
e63406c0a8
|
@ -3,18 +3,18 @@ from sys import argv
|
||||||
|
|
||||||
from matplotlib.pyplot import *
|
from matplotlib.pyplot import *
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
from seaborn import heatmap, set_style, set_theme, pairplot
|
from seaborn import clustermap, set_style, set_theme, pairplot
|
||||||
from sklearn.metrics import silhouette_score, calinski_harabasz_score
|
from sklearn.metrics import silhouette_score, calinski_harabasz_score
|
||||||
from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
|
from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
|
||||||
|
|
||||||
from preprocessing import parse_data, filter_dataframe
|
from preprocessing import parse_data, filter_dataframe
|
||||||
|
|
||||||
|
|
||||||
def choose_model(model):
|
def choose_model(model, cluster_number):
|
||||||
if model == "kmeans":
|
if model == "kmeans":
|
||||||
return KMeans(random_state=42)
|
return KMeans(n_clusters=cluster_number, random_state=42)
|
||||||
elif model == "birch":
|
elif model == "birch":
|
||||||
return Birch()
|
return Birch(n_clusters=cluster_number)
|
||||||
elif model == "spectral":
|
elif model == "spectral":
|
||||||
return SpectralClustering()
|
return SpectralClustering()
|
||||||
elif model == "meanshift":
|
elif model == "meanshift":
|
||||||
|
@ -23,9 +23,9 @@ def choose_model(model):
|
||||||
return DBSCAN()
|
return DBSCAN()
|
||||||
|
|
||||||
|
|
||||||
def predict_data(data, model, results, sample):
|
def predict_data(data, model, cluster_number, results):
|
||||||
model_name = model
|
model_name = model
|
||||||
model = choose_model(model)
|
model = choose_model(model=model, cluster_number=cluster_number)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
prediction = model.fit_predict(data)
|
prediction = model.fit_predict(data)
|
||||||
execution_time = time.time() - start_time
|
execution_time = time.time() - start_time
|
||||||
|
@ -35,7 +35,6 @@ def predict_data(data, model, results, sample):
|
||||||
X=data,
|
X=data,
|
||||||
labels=prediction,
|
labels=prediction,
|
||||||
metric="euclidean",
|
metric="euclidean",
|
||||||
sample_size=sample,
|
|
||||||
random_state=42,
|
random_state=42,
|
||||||
)
|
)
|
||||||
populated_results = populate_results(
|
populated_results = populate_results(
|
||||||
|
@ -52,10 +51,13 @@ def predict_data(data, model, results, sample):
|
||||||
|
|
||||||
def plot_heatmap(results):
|
def plot_heatmap(results):
|
||||||
fig = figure(figsize=(20, 10))
|
fig = figure(figsize=(20, 10))
|
||||||
heatmap(
|
results.reset_index()
|
||||||
data=results,
|
matrix = results["prediction"]
|
||||||
cmap="Blues",
|
print(matrix)
|
||||||
square=True,
|
clustermap(
|
||||||
|
data=matrix,
|
||||||
|
cmap="mako",
|
||||||
|
metric="euclidean",
|
||||||
annot=True,
|
annot=True,
|
||||||
)
|
)
|
||||||
fig_title = "Heatmap"
|
fig_title = "Heatmap"
|
||||||
|
@ -66,10 +68,10 @@ def plot_heatmap(results):
|
||||||
|
|
||||||
def plot_scatter_plot(results):
|
def plot_scatter_plot(results):
|
||||||
fig = figure(figsize=(20, 10))
|
fig = figure(figsize=(20, 10))
|
||||||
original_data = results.drop("prediction")
|
matrix = results.filter(items=["input", "prediction"])
|
||||||
pairplot(
|
pairplot(
|
||||||
data=results,
|
data=results,
|
||||||
vars=original_data,
|
vars=matrix,
|
||||||
hue="prediction",
|
hue="prediction",
|
||||||
palette="Paired",
|
palette="Paired",
|
||||||
diag_kind="hist",
|
diag_kind="hist",
|
||||||
|
@ -138,12 +140,14 @@ def construct_case(df, choice):
|
||||||
|
|
||||||
|
|
||||||
def usage():
|
def usage():
|
||||||
print("Usage: " + argv[0] + "<preprocessing action> <case> <sample size>")
|
print("Usage: " + argv[0] + "<preprocessing action> <case> <number of clusters>")
|
||||||
print("preprocessing actions:")
|
print("preprocessing actions:")
|
||||||
print("fill: fills the na values with the mean")
|
print("fill: fills the na values with the mean")
|
||||||
print("drop: drops the na values")
|
print("drop: drops the na values")
|
||||||
print("cases: choice of case study")
|
print("cases: choice of case study")
|
||||||
print("sample size: size of the sample when computing the Silhouette Coefficient")
|
print(
|
||||||
|
"number of clusters: number of clusters for the algorithms that use a fixed number"
|
||||||
|
)
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
@ -151,7 +155,7 @@ def main():
|
||||||
models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
|
models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
|
||||||
if len(argv) != 4:
|
if len(argv) != 4:
|
||||||
usage()
|
usage()
|
||||||
case, sample = argv[2], int(argv[3])
|
case, cluster_number = argv[2], int(argv[3])
|
||||||
data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1]))
|
data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1]))
|
||||||
individual_result, complete_results = create_result_dataframes()
|
individual_result, complete_results = create_result_dataframes()
|
||||||
case_data = construct_case(df=data, choice=case)
|
case_data = construct_case(df=data, choice=case)
|
||||||
|
@ -161,13 +165,13 @@ def main():
|
||||||
data=filtered_data,
|
data=filtered_data,
|
||||||
model=model,
|
model=model,
|
||||||
results=individual_result,
|
results=individual_result,
|
||||||
sample=sample,
|
cluster_number=cluster_number,
|
||||||
)
|
)
|
||||||
complete_results = complete_results.append(
|
complete_results = complete_results.append(
|
||||||
individual_result.append(model_results)
|
individual_result.append(model_results)
|
||||||
)
|
)
|
||||||
complete_results.set_index("model")
|
complete_results.set_index("model")
|
||||||
print(complete_results)
|
show_results(results=complete_results)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue