From f8c1a54be3d7f61ad37a60d4b1b4cbbb298ea1ac Mon Sep 17 00:00:00 2001 From: coolneng Date: Thu, 3 Jun 2021 18:29:43 +0200 Subject: [PATCH] Apply index-based encoding to the DNA sequence --- src/preprocessing.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/preprocessing.py b/src/preprocessing.py index 34cb1b6..1ab533e 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -3,12 +3,12 @@ from typing import List from Bio.motifs import create from Bio.SeqIO import parse from numpy.random import random -from tensorflow import float32, string +from tensorflow import float32, int64 from tensorflow.data import TFRecordDataset from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example -from tensorflow.train import BytesList, Example, Feature, Features, FloatList +from tensorflow.train import Example, Feature, Features, FloatList, Int64List -from constants import BATCH_SIZE, EPOCHS, TEST_DATASET, TRAIN_DATASET +from constants import BASES, BATCH_SIZE, EPOCHS, TEST_DATASET, TRAIN_DATASET def generate_example(sequence, weight_matrix) -> bytes: @@ -16,7 +16,9 @@ def generate_example(sequence, weight_matrix) -> bytes: Create a binary-string for each sequence containing the sequence and the bases' frequency """ schema = { - "sequence": Feature(bytes_list=BytesList(value=[sequence.encode()])), + "sequence": Feature( + int64_list=Int64List(value=list(encode_sequence(sequence))) + ), "A_counts": Feature(float_list=FloatList(value=[weight_matrix["A"][0]])), "C_counts": Feature(float_list=FloatList(value=[weight_matrix["C"][0]])), "G_counts": Feature(float_list=FloatList(value=[weight_matrix["G"][0]])), @@ -26,6 +28,14 @@ def generate_example(sequence, weight_matrix) -> bytes: return example.SerializeToString() +def encode_sequence(sequence) -> List[int]: + """ + Encode the DNA sequence using the indices of the BASES constant + """ + encoded_sequence = [BASES.index(element) for element in sequence] + return encoded_sequence + + def parse_data(filepath) -> List[bytes]: """ Parse a FASTQ file and generate a List of serialized Examples @@ -56,7 +66,7 @@ def create_dataset(filepath) -> None: def process_input(byte_string): schema = { - "sequence": FixedLenFeature(shape=[], dtype=string), + "sequence": FixedLenFeature(shape=[], dtype=int64), "A_counts": FixedLenFeature(shape=[], dtype=float32), "C_counts": FixedLenFeature(shape=[], dtype=float32), "G_counts": FixedLenFeature(shape=[], dtype=float32),