Add processing module

This commit is contained in:
coolneng 2021-01-01 21:06:09 +01:00
parent 793ba5fffb
commit abfc877c7d
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed files with 51 additions and 0 deletions

51
src/processing.py Normal file
View File

@ -0,0 +1,51 @@
from numpy import mean
from pandas import DataFrame
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import GradientBoostingClassifier
from preprocessing import parse_data, split_k_sets
def predict_data(train_data, train_target, test_data, test_ids):
model = GradientBoostingClassifier(random_state=42)
accuracy_scores = []
for train_index, test_index in split_k_sets(train_data):
model.fit(train_data.iloc[train_index], train_target.iloc[train_index])
prediction = model.predict(train_data.iloc[test_index])
accuracy_scores.append(
accuracy_score(train_target.iloc[test_index], prediction)
)
cv_score = cross_val_score(model, train_data, train_target)
evaluate_performance(
accuracy=mean(accuracy_scores),
cv_score=mean(cv_score),
)
predictions = model.predict(test_data)
export_results(ids=test_ids, prediction=predictions)
def evaluate_performance(accuracy, cv_score):
print("Accuracy Score: " + str(accuracy))
print("Cross validation score: " + str(cv_score))
def export_results(ids, prediction):
result_df = DataFrame({"id": ids, "Precio_cat": prediction})
result_df.to_csv(path_or_buf="data/results.csv", index=False)
def main():
train_data, train_target, test_data, test_ids = parse_data(
train="data/train.csv", test="data/test.csv"
)
predict_data(
train_data=train_data,
train_target=train_target,
test_data=test_data,
test_ids=test_ids,
)
if __name__ == "__main__":
main()