From 8870da85430346b371cae154cfab89b3b5db50ab Mon Sep 17 00:00:00 2001 From: coolneng Date: Sun, 6 Jun 2021 00:03:39 +0200 Subject: [PATCH] Create a validation set --- src/constants.py | 2 ++ src/preprocessing.py | 19 +++++++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/constants.py b/src/constants.py index c44726b..e4c3746 100644 --- a/src/constants.py +++ b/src/constants.py @@ -1,7 +1,9 @@ BASES = "ACGT" TRAIN_DATASET = "data/train_data.tfrecords" TEST_DATASET = "data/test_data.tfrecords" +EVAL_DATASET = "data/eval_data.tfrecords" EPOCHS = 1000 BATCH_SIZE = 256 LEARNING_RATE = 0.004 L2 = 0.001 +LOG_DIR = "logs" diff --git a/src/preprocessing.py b/src/preprocessing.py index b9353d1..a08b6cf 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -59,11 +59,15 @@ def create_dataset(data_file, label_file) -> None: Create a training and test dataset with a 70/30 split respectively """ data = read_fastq(data_file, label_file) - train_test_split = 0.7 - with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test: + train_eval_test_split = [0.8, 0.1, 0.1] + with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter( + TEST_DATASET + ) as test, TFRecordWriter(EVAL_DATASET) as evaluation: for element in data: - if random() < train_test_split: - train.write(element) + if random() < train_eval_test_split[0]: + training.write(element) + elif random() < train_eval_test_split[0] + train_eval_test_split[1]: + evaluation.write(element) else: test.write(element) @@ -94,8 +98,11 @@ def read_dataset(filepath) -> TFRecordDataset: return batched_dataset -def dataset_creation(data_file, label_file) -> Tuple[TFRecordDataset, TFRecordDataset]: +def dataset_creation( + data_file, label_file +) -> Tuple[TFRecordDataset, TFRecordDataset, TFRecordDataset]: create_dataset(data_file, label_file) train_data = read_dataset(TRAIN_DATASET) + eval_data = read_dataset(EVAL_DATASET) test_data = read_dataset(TEST_DATASET) - return train_data, test_data + return train_data, eval_data, test_data