Apply index-based encoding to the DNA sequence

This commit is contained in:
coolneng 2021-06-03 18:29:43 +02:00
parent d34e291085
commit f8c1a54be3
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed files with 15 additions and 5 deletions

View File

@ -3,12 +3,12 @@ from typing import List
from Bio.motifs import create from Bio.motifs import create
from Bio.SeqIO import parse from Bio.SeqIO import parse
from numpy.random import random from numpy.random import random
from tensorflow import float32, string from tensorflow import float32, int64
from tensorflow.data import TFRecordDataset from tensorflow.data import TFRecordDataset
from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example
from tensorflow.train import BytesList, Example, Feature, Features, FloatList from tensorflow.train import Example, Feature, Features, FloatList, Int64List
from constants import BATCH_SIZE, EPOCHS, TEST_DATASET, TRAIN_DATASET from constants import BASES, BATCH_SIZE, EPOCHS, TEST_DATASET, TRAIN_DATASET
def generate_example(sequence, weight_matrix) -> bytes: def generate_example(sequence, weight_matrix) -> bytes:
@ -16,7 +16,9 @@ def generate_example(sequence, weight_matrix) -> bytes:
Create a binary-string for each sequence containing the sequence and the bases' frequency Create a binary-string for each sequence containing the sequence and the bases' frequency
""" """
schema = { schema = {
"sequence": Feature(bytes_list=BytesList(value=[sequence.encode()])), "sequence": Feature(
int64_list=Int64List(value=list(encode_sequence(sequence)))
),
"A_counts": Feature(float_list=FloatList(value=[weight_matrix["A"][0]])), "A_counts": Feature(float_list=FloatList(value=[weight_matrix["A"][0]])),
"C_counts": Feature(float_list=FloatList(value=[weight_matrix["C"][0]])), "C_counts": Feature(float_list=FloatList(value=[weight_matrix["C"][0]])),
"G_counts": Feature(float_list=FloatList(value=[weight_matrix["G"][0]])), "G_counts": Feature(float_list=FloatList(value=[weight_matrix["G"][0]])),
@ -26,6 +28,14 @@ def generate_example(sequence, weight_matrix) -> bytes:
return example.SerializeToString() return example.SerializeToString()
def encode_sequence(sequence) -> List[int]:
"""
Encode the DNA sequence using the indices of the BASES constant
"""
encoded_sequence = [BASES.index(element) for element in sequence]
return encoded_sequence
def parse_data(filepath) -> List[bytes]: def parse_data(filepath) -> List[bytes]:
""" """
Parse a FASTQ file and generate a List of serialized Examples Parse a FASTQ file and generate a List of serialized Examples
@ -56,7 +66,7 @@ def create_dataset(filepath) -> None:
def process_input(byte_string): def process_input(byte_string):
schema = { schema = {
"sequence": FixedLenFeature(shape=[], dtype=string), "sequence": FixedLenFeature(shape=[], dtype=int64),
"A_counts": FixedLenFeature(shape=[], dtype=float32), "A_counts": FixedLenFeature(shape=[], dtype=float32),
"C_counts": FixedLenFeature(shape=[], dtype=float32), "C_counts": FixedLenFeature(shape=[], dtype=float32),
"G_counts": FixedLenFeature(shape=[], dtype=float32), "G_counts": FixedLenFeature(shape=[], dtype=float32),