Create a validation set
This commit is contained in:
parent
38903c5737
commit
8870da8543
|
@ -1,7 +1,9 @@
|
||||||
BASES = "ACGT"
|
BASES = "ACGT"
|
||||||
TRAIN_DATASET = "data/train_data.tfrecords"
|
TRAIN_DATASET = "data/train_data.tfrecords"
|
||||||
TEST_DATASET = "data/test_data.tfrecords"
|
TEST_DATASET = "data/test_data.tfrecords"
|
||||||
|
EVAL_DATASET = "data/eval_data.tfrecords"
|
||||||
EPOCHS = 1000
|
EPOCHS = 1000
|
||||||
BATCH_SIZE = 256
|
BATCH_SIZE = 256
|
||||||
LEARNING_RATE = 0.004
|
LEARNING_RATE = 0.004
|
||||||
L2 = 0.001
|
L2 = 0.001
|
||||||
|
LOG_DIR = "logs"
|
||||||
|
|
|
@ -59,11 +59,15 @@ def create_dataset(data_file, label_file) -> None:
|
||||||
Create a training and test dataset with a 70/30 split respectively
|
Create training, evaluation and test datasets, split by the ratios in train_eval_test_split
|
||||||
"""
|
"""
|
||||||
data = read_fastq(data_file, label_file)
|
data = read_fastq(data_file, label_file)
|
||||||
train_test_split = 0.7
|
train_eval_test_split = [0.8, 0.1, 0.1]
|
||||||
with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test:
|
with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
|
||||||
|
TEST_DATASET
|
||||||
|
) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
|
||||||
for element in data:
|
for element in data:
|
||||||
if random() < train_test_split:
|
if random() < train_eval_test_split[0]:
|
||||||
train.write(element)
|
training.write(element)
|
||||||
|
elif random() < train_eval_test_split[0] + train_eval_test_split[1]:
|
||||||
|
evaluation.write(element)
|
||||||
else:
|
else:
|
||||||
test.write(element)
|
test.write(element)
|
||||||
|
|
||||||
|
@ -94,8 +98,11 @@ def read_dataset(filepath) -> TFRecordDataset:
|
||||||
return batched_dataset
|
return batched_dataset
|
||||||
|
|
||||||
|
|
||||||
def dataset_creation(data_file, label_file) -> Tuple[TFRecordDataset, TFRecordDataset]:
|
def dataset_creation(
|
||||||
|
data_file, label_file
|
||||||
|
) -> Tuple[TFRecordDataset, TFRecordDataset, TFRecordDataset]:
|
||||||
create_dataset(data_file, label_file)
|
create_dataset(data_file, label_file)
|
||||||
train_data = read_dataset(TRAIN_DATASET)
|
train_data = read_dataset(TRAIN_DATASET)
|
||||||
|
eval_data = read_dataset(EVAL_DATASET)
|
||||||
test_data = read_dataset(TEST_DATASET)
|
test_data = read_dataset(TEST_DATASET)
|
||||||
return train_data, test_data
|
return train_data, eval_data, test_data
|
||||||
|
|
Loading…
Reference in New Issue