Compare commits

...

2 Commits

Author SHA1 Message Date
coolneng 2920db70b4
Create a validation set 2021-06-06 00:03:39 +02:00
coolneng 38903c5737
Rename ref_sequence to label 2021-06-06 00:03:15 +02:00
1 changed file with 18 additions and 13 deletions

View File

@ -11,7 +11,7 @@ from tensorflow.train import Example, Feature, Features, FloatList, Int64List
from constants import * from constants import *
def generate_example(sequence, reference_sequence, weight_matrix) -> bytes: def generate_example(sequence, label, weight_matrix) -> bytes:
""" """
Create a binary-string for each sequence containing the sequence and the bases' frequency Create a binary-string for each sequence containing the sequence and the bases' frequency
""" """
@ -19,9 +19,7 @@ def generate_example(sequence, reference_sequence, weight_matrix) -> bytes:
"sequence": Feature( "sequence": Feature(
int64_list=Int64List(value=list(encode_sequence(sequence))) int64_list=Int64List(value=list(encode_sequence(sequence)))
), ),
"reference_sequence": Feature( "label": Feature(int64_list=Int64List(value=list(encode_sequence(label)))),
int64_list=Int64List(value=list(encode_sequence(reference_sequence)))
),
"A_counts": Feature(float_list=FloatList(value=weight_matrix["A"])), "A_counts": Feature(float_list=FloatList(value=weight_matrix["A"])),
"C_counts": Feature(float_list=FloatList(value=weight_matrix["C"])), "C_counts": Feature(float_list=FloatList(value=weight_matrix["C"])),
"G_counts": Feature(float_list=FloatList(value=weight_matrix["G"])), "G_counts": Feature(float_list=FloatList(value=weight_matrix["G"])),
@ -49,23 +47,27 @@ def read_fastq(data_file, label_file) -> List[bytes]:
motifs = create([element.seq]) motifs = create([element.seq])
example = generate_example( example = generate_example(
sequence=str(element.seq), sequence=str(element.seq),
reference_sequence=str(label.seq), label=str(label.seq),
weight_matrix=motifs.pwm, weight_matrix=motifs.pwm,
) )
examples.append(example) examples.append(example)
return examples return examples
def create_dataset(filepath) -> None: def create_dataset(data_file, label_file) -> None:
""" """
Create a training and test dataset with a 70/30 split respectively Create a training and test dataset with a 70/30 split respectively
""" """
data = read_fastq(data_file, label_file) data = read_fastq(data_file, label_file)
train_test_split = 0.7 train_eval_test_split = [0.8, 0.1, 0.1]
with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test: with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
TEST_DATASET
) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
for element in data: for element in data:
if random() < train_test_split: if random() < train_eval_test_split[0]:
train.write(element) training.write(element)
elif random() < train_eval_test_split[0] + train_eval_test_split[1]:
evaluation.write(element)
else: else:
test.write(element) test.write(element)
@ -76,7 +78,7 @@ def process_input(byte_string) -> Example:
""" """
schema = { schema = {
"sequence": FixedLenFeature(shape=[], dtype=int64), "sequence": FixedLenFeature(shape=[], dtype=int64),
"reference_sequence": FixedLenFeature(shape=[], dtype=int64), "label": FixedLenFeature(shape=[], dtype=int64),
"A_counts": FixedLenFeature(shape=[], dtype=float32), "A_counts": FixedLenFeature(shape=[], dtype=float32),
"C_counts": FixedLenFeature(shape=[], dtype=float32), "C_counts": FixedLenFeature(shape=[], dtype=float32),
"G_counts": FixedLenFeature(shape=[], dtype=float32), "G_counts": FixedLenFeature(shape=[], dtype=float32),
@ -96,8 +98,11 @@ def read_dataset(filepath) -> TFRecordDataset:
return batched_dataset return batched_dataset
def dataset_creation(data_file, label_file) -> Tuple[TFRecordDataset, TFRecordDataset]: def dataset_creation(
data_file, label_file
) -> Tuple[TFRecordDataset, TFRecordDataset, TFRecordDataset]:
create_dataset(data_file, label_file) create_dataset(data_file, label_file)
train_data = read_dataset(TRAIN_DATASET) train_data = read_dataset(TRAIN_DATASET)
eval_data = read_dataset(EVAL_DATASET)
test_data = read_dataset(TEST_DATASET) test_data = read_dataset(TEST_DATASET)
return train_data, test_data return train_data, eval_data, test_data