Create a validation set

coolneng 2021-06-06 00:03:39 +02:00
parent 38903c5737
commit 8870da8543
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
2 changed files with 15 additions and 6 deletions


@@ -1,7 +1,9 @@
 BASES = "ACGT"
 TRAIN_DATASET = "data/train_data.tfrecords"
 TEST_DATASET = "data/test_data.tfrecords"
+EVAL_DATASET = "data/eval_data.tfrecords"
 EPOCHS = 1000
 BATCH_SIZE = 256
 LEARNING_RATE = 0.004
 L2 = 0.001
+LOG_DIR = "logs"
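The new EVAL_DATASET and LOG_DIR constants are only declared here. A minimal sketch of how they might be consumed, assuming the constants live in a module named constants and that training uses Keras with a TensorBoard callback (neither is shown in this commit):

# Hypothetical usage of the new constants; the module name "constants",
# the Keras callback, and the batching are assumptions, not part of the diff.
from tensorflow.data import TFRecordDataset
from tensorflow.keras.callbacks import TensorBoard

from constants import BATCH_SIZE, EVAL_DATASET, LOG_DIR

# Raw records from the new validation file; record parsing is project-specific.
raw_eval = TFRecordDataset(EVAL_DATASET).batch(BATCH_SIZE)

# LOG_DIR would typically back a TensorBoard callback passed to model.fit().
tensorboard = TensorBoard(log_dir=LOG_DIR)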


@@ -59,11 +59,15 @@ def create_dataset(data_file, label_file) -> None:
     """
     Create a training and test dataset with a 70/30 split respectively
     """
     data = read_fastq(data_file, label_file)
-    train_test_split = 0.7
-    with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test:
+    train_eval_test_split = [0.8, 0.1, 0.1]
+    with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
+        TEST_DATASET
+    ) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
         for element in data:
-            if random() < train_test_split:
-                train.write(element)
+            if random() < train_eval_test_split[0]:
+                training.write(element)
+            elif random() < train_eval_test_split[0] + train_eval_test_split[1]:
+                evaluation.write(element)
             else:
                 test.write(element)
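Note that the new branch logic draws random() twice for any element that misses the training bucket, so the realised proportions are roughly 0.8 / 0.18 / 0.02 rather than 0.8 / 0.1 / 0.1 (the evaluation branch fires with probability (1 - 0.8) * 0.9). A single draw per element, compared against cumulative bounds, keeps the intended split. The sketch below reuses the writer names from the diff and assumes data and the three writers are in scope from the same with statement; nothing else about create_dataset is changed:

# Single-draw variant of the splitting loop; "data", "training", "evaluation"
# and "test" are assumed to come from the surrounding with block in the diff.
from random import random

train_eval_test_split = [0.8, 0.1, 0.1]
for element in data:
    draw = random()  # one draw per element, checked against cumulative bounds
    if draw < train_eval_test_split[0]:
        training.write(element)
    elif draw < train_eval_test_split[0] + train_eval_test_split[1]:
        evaluation.write(element)
    else:
        test.write(element)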
@@ -94,8 +98,11 @@ def read_dataset(filepath) -> TFRecordDataset:
     return batched_dataset


-def dataset_creation(data_file, label_file) -> Tuple[TFRecordDataset, TFRecordDataset]:
+def dataset_creation(
+    data_file, label_file
+) -> Tuple[TFRecordDataset, TFRecordDataset, TFRecordDataset]:
     create_dataset(data_file, label_file)
     train_data = read_dataset(TRAIN_DATASET)
+    eval_data = read_dataset(EVAL_DATASET)
     test_data = read_dataset(TEST_DATASET)
-    return train_data, test_data
+    return train_data, eval_data, test_data
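For reference, a sketch of how a caller might use the three-way return value; the model object, the EPOCHS constant import, and the input paths are placeholders, not part of this commit:

# Hypothetical caller of dataset_creation; "model" and the file paths are
# placeholders, and EPOCHS is assumed to come from the constants shown above.
train_data, eval_data, test_data = dataset_creation(
    "data/reads.fastq", "data/labels.fastq"  # hypothetical input files
)
model.fit(
    train_data,
    validation_data=eval_data,  # the new validation split
    epochs=EPOCHS,
)
model.evaluate(test_data)  # held-out test data, untouched during training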