Fix sample size selection

This commit is contained in:
coolneng 2020-12-11 14:37:01 +01:00
parent 8bcc7fa7bc
commit 6fe18d594c
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed file with 6 additions and 6 deletions

View File

@@ -35,6 +35,7 @@ def predict_data(data, model, results, sample):
labels=prediction, labels=prediction,
metric="euclidean", metric="euclidean",
sample_size=sample, sample_size=sample,
random_state=42,
) )
populated_results = populate_results( populated_results = populate_results(
df=results, df=results,
@@ -80,8 +81,7 @@ def plot_scatter_plot(results):
def print_dataframe(df): def print_dataframe(df):
df.set_index("model") df.set_index("model")
output_df = df.filter["clusters", "silhouette", "calinski", "time"] print(df)
print(output_df)
def show_results(results): def show_results(results):
@@ -95,8 +95,8 @@ def show_results(results):
def create_result_dataframes(): def create_result_dataframes():
results = DataFrame( results = DataFrame(
columns=[ columns=[
"clusters",
"model", "model",
"clusters",
"prediction", "prediction",
"silhouette", "silhouette",
"calinski-harabasz", "calinski-harabasz",
@@ -156,7 +156,7 @@ def main():
models = ["kmeans", "birch", "affinity", "meanshift", "dbscan"] models = ["kmeans", "birch", "affinity", "meanshift", "dbscan"]
if len(argv) != 4: if len(argv) != 4:
usage() usage()
case, sample = argv[2], argv[3] case, sample = argv[2], int(argv[3])
data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1])) data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1]))
individual_result, complete_results = create_result_dataframes() individual_result, complete_results = create_result_dataframes()
case_data = construct_case(df=data, choice=case) case_data = construct_case(df=data, choice=case)
@@ -171,8 +171,8 @@ def main():
complete_results = complete_results.append( complete_results = complete_results.append(
individual_result.append(model_results) individual_result.append(model_results)
) )
indexed_results = complete_results.set_index("model") complete_results.set_index("model")
show_results(results=indexed_results) print_dataframe(df=complete_results)
if __name__ == "__main__": if __name__ == "__main__":