Simplify model selection logic
This commit is contained in:
parent
4652d46966
commit
de21b65ca0
|
@ -10,35 +10,15 @@ from sklearn.tree import DecisionTreeClassifier
|
||||||
from preprocessing import parse_data, split_k_sets
|
from preprocessing import parse_data, split_k_sets
|
||||||
|
|
||||||
|
|
||||||
def naive_bayes():
|
|
||||||
model = GaussianNB()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def linear_svc():
|
|
||||||
model = LinearSVC(random_state=42)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def k_nearest_neighbors():
|
|
||||||
model = KNeighborsClassifier(n_neighbors=10)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def decision_tree():
|
|
||||||
model = DecisionTreeClassifier(random_state=42)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def choose_model(model):
|
def choose_model(model):
|
||||||
if model == "gnb":
|
if model == "gnb":
|
||||||
return naive_bayes()
|
return GaussianNB()
|
||||||
elif model == "svc":
|
elif model == "svc":
|
||||||
return linear_svc()
|
return LinearSVC(random_state=42)
|
||||||
elif model == "knn":
|
elif model == "knn":
|
||||||
return k_nearest_neighbors()
|
return KNeighborsClassifier(n_neighbors=10)
|
||||||
elif model == "tree":
|
elif model == "tree":
|
||||||
return decision_tree()
|
return DecisionTreeClassifier(random_state=42)
|
||||||
|
|
||||||
|
|
||||||
def predict_data(data, target, model):
|
def predict_data(data, target, model):
|
||||||
|
|
Loading…
Reference in New Issue