diff --git a/src/processing.py b/src/processing.py index 514230f..a0bd7c3 100644 --- a/src/processing.py +++ b/src/processing.py @@ -2,6 +2,7 @@ from numpy import mean from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score from sklearn.model_selection import cross_val_score from sklearn.naive_bayes import GaussianNB +from sklearn.neural_network import MLPClassifier from sklearn.neighbors import KNeighborsClassifier from sklearn.preprocessing import scale from sklearn.svm import LinearSVC @@ -19,6 +20,8 @@ def choose_model(model): return KNeighborsClassifier(n_neighbors=10) elif model == "tree": return DecisionTreeClassifier(random_state=42) + elif model == "neuralnet": + return MLPClassifier(hidden_layer_sizes=10) def predict_data(data, target, model): @@ -53,7 +56,7 @@ def evaluate_performance(confusion_matrix, accuracy, cv_score, auc): def main(): data, target = parse_data(source="data/mamografia.csv", action="drop") - predict_data(data=data, target=target, model="gnb") + predict_data(data=data, target=target, model="neuralnet") if __name__ == "__main__":