From 02d20d4e72c1717f06c28318ebe905e2ef0101d4 Mon Sep 17 00:00:00 2001 From: coolneng Date: Sat, 5 Jun 2021 20:34:59 +0200 Subject: [PATCH] Add reference sequence to each dataset instance --- src/preprocessing.py | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/src/preprocessing.py b/src/preprocessing.py index 98fb80a..caf8019 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Tuple from Bio.motifs import create from Bio.SeqIO import parse @@ -8,10 +8,10 @@ from tensorflow.data import TFRecordDataset from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example from tensorflow.train import Example, Feature, Features, FloatList, Int64List -from constants import BASES, BATCH_SIZE, EPOCHS, TEST_DATASET, TRAIN_DATASET +from constants import * -def generate_example(sequence, weight_matrix) -> bytes: +def generate_example(sequence, reference_sequence, weight_matrix) -> bytes: """ Create a binary-string for each sequence containing the sequence and the bases' frequency """ @@ -19,6 +19,9 @@ def generate_example(sequence, weight_matrix) -> bytes: "sequence": Feature( int64_list=Int64List(value=list(encode_sequence(sequence))) ), + "reference_sequence": Feature( + int64_list=Int64List(value=list(encode_sequence(reference_sequence))) + ), "A_counts": Feature(float_list=FloatList(value=[weight_matrix["A"][0]])), "C_counts": Feature(float_list=FloatList(value=[weight_matrix["C"][0]])), "G_counts": Feature(float_list=FloatList(value=[weight_matrix["G"][0]])), @@ -36,16 +39,19 @@ def encode_sequence(sequence) -> List[int]: return encoded_sequence -def read_fastq(filepath) -> List[bytes]: +def read_fastq(data_file, label_file) -> List[bytes]: """ - Parse a FASTQ file and generate a List of serialized Examples + Parses a data and a label FASTQ files and generates a List of serialized Examples """ examples = [] - with open(filepath) as handle: - for row in parse(handle, "fastq"): - sequence = str(row.seq) - motifs = create(row.seq) - example = generate_example(sequence=sequence, weight_matrix=motifs.pwm) + with open(data_file) as data, open(label_file) as labels: + for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")): + motifs = create([element.seq]) + example = generate_example( + sequence=str(element.seq), + reference_sequence=str(label.seq), + weight_matrix=motifs.pwm, + ) examples.append(example) return examples @@ -54,7 +60,7 @@ def create_dataset(filepath) -> None: """ Create a training and test dataset with a 70/30 split respectively """ - data = read_fastq(filepath) + data = read_fastq(data_file, label_file) train_test_split = 0.7 with TFRecordWriter(TRAIN_DATASET) as train, TFRecordWriter(TEST_DATASET) as test: for element in data: @@ -70,6 +76,7 @@ def process_input(byte_string) -> Example: """ schema = { "sequence": FixedLenFeature(shape=[], dtype=int64), + "reference_sequence": FixedLenFeature(shape=[], dtype=int64), "A_counts": FixedLenFeature(shape=[], dtype=float32), "C_counts": FixedLenFeature(shape=[], dtype=float32), "G_counts": FixedLenFeature(shape=[], dtype=float32), @@ -78,12 +85,19 @@ def process_input(byte_string) -> Example: return parse_single_example(byte_string, features=schema) -def read_dataset() -> TFRecordDataset: +def read_dataset(filepath) -> TFRecordDataset: """ Read TFRecords files and generate a dataset """ - data_input = TFRecordDataset(filenames=[TRAIN_DATASET, TEST_DATASET]) + data_input = TFRecordDataset(filenames=filepath) dataset = data_input.map(map_func=process_input) - shuffled_dataset = dataset.shuffle(buffer_size=10000, reshuffle_each_iteration=True) + shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42) batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS) return batched_dataset + + +def dataset_creation(data_file, label_file) -> Tuple[TFRecordDataset, TFRecordDataset]: + create_dataset(data_file, label_file) + train_data = read_dataset(TRAIN_DATASET) + test_data = read_dataset(TEST_DATASET) + return train_data, test_data