From 2920db70b4fbff05dedc08be3fe68533245a41d5 Mon Sep 17 00:00:00 2001
From: coolneng
Date: Sun, 6 Jun 2021 00:03:39 +0200
Subject: [PATCH] Create a validation set

---
Review notes (text between the "---" line and the diff is not part of the
commit message):
- create_dataset now calls random() once per element and reuses the value.
  Calling random() again in the elif made the branches independent, giving an
  effective split of roughly 80/18/2 instead of the intended 80/10/10.
- The create_dataset docstring is updated to describe the three-way split.
- The blob id on the "index" line predates these review changes.

 src/preprocessing.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/src/preprocessing.py b/src/preprocessing.py
index b9353d1..a08b6cf 100644
--- a/src/preprocessing.py
+++ b/src/preprocessing.py
@@ -59,11 +59,17 @@ def create_dataset(data_file, label_file) -> None:
-    Create a training and test dataset with a 70/30 split respectively
+    Create a training, validation and test dataset with an 80/10/10 split
     """
     data = read_fastq(data_file, label_file)
-    train_test_split = 0.7
-    with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test:
+    train_eval_test_split = [0.8, 0.1, 0.1]
+    with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
+        TEST_DATASET
+    ) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
         for element in data:
-            if random() < train_test_split:
-                train.write(element)
+            # Draw once per element so the three branches partition [0, 1)
+            draw = random()
+            if draw < train_eval_test_split[0]:
+                training.write(element)
+            elif draw < train_eval_test_split[0] + train_eval_test_split[1]:
+                evaluation.write(element)
             else:
                 test.write(element)
 
@@ -94,8 +100,11 @@ def read_dataset(filepath) -> TFRecordDataset:
     return batched_dataset
 
 
-def dataset_creation(data_file, label_file) -> Tuple[TFRecordDataset, TFRecordDataset]:
+def dataset_creation(
+    data_file, label_file
+) -> Tuple[TFRecordDataset, TFRecordDataset, TFRecordDataset]:
     create_dataset(data_file, label_file)
     train_data = read_dataset(TRAIN_DATASET)
+    eval_data = read_dataset(EVAL_DATASET)
     test_data = read_dataset(TEST_DATASET)
-    return train_data, test_data
+    return train_data, eval_data, test_data