From 19ed847d122741bcb5a9149b493a8094da666c4d Mon Sep 17 00:00:00 2001 From: coolneng Date: Mon, 14 Jun 2021 19:33:42 +0200 Subject: [PATCH] Convert sequence and label to VarLenFeature --- src/preprocessing.py | 52 +++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/src/preprocessing.py b/src/preprocessing.py index 102135c..b2beb41 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -3,27 +3,31 @@ from typing import List, Tuple from Bio.motifs import create from Bio.SeqIO import parse from numpy.random import random -from tensorflow import float32, int64 +from tensorflow import Tensor, int64, stack +from tensorflow.sparse import to_dense from tensorflow.data import TFRecordDataset -from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example -from tensorflow.train import Example, Feature, Features, FloatList, Int64List +from tensorflow.io import ( + FixedLenFeature, + TFRecordWriter, + VarLenFeature, + parse_single_example, +) +from tensorflow.train import Example, Feature, Features, Int64List from constants import * -def generate_example(sequence, label, weight_matrix) -> bytes: +def generate_example(sequence, label, base_counts) -> bytes: """ - Create a binary-string for each sequence containing the sequence and the bases' frequency + Create a binary-string for each sequence containing the sequence and the bases' counts """ schema = { - "sequence": Feature( - int64_list=Int64List(value=list(encode_sequence(sequence))) - ), - "label": Feature(int64_list=Int64List(value=list(encode_sequence(label)))), - "A_counts": Feature(float_list=FloatList(value=weight_matrix["A"])), - "C_counts": Feature(float_list=FloatList(value=weight_matrix["C"])), - "G_counts": Feature(float_list=FloatList(value=weight_matrix["G"])), - "T_counts": Feature(float_list=FloatList(value=weight_matrix["T"])), + "A_counts": Feature(int64_list=Int64List(value=[sum(base_counts["A"])])), + "C_counts": Feature(int64_list=Int64List(value=[sum(base_counts["C"])])), + "G_counts": Feature(int64_list=Int64List(value=[sum(base_counts["G"])])), + "T_counts": Feature(int64_list=Int64List(value=[sum(base_counts["T"])])), + "sequence": Feature(int64_list=Int64List(value=encode_sequence(sequence))), + "label": Feature(int64_list=Int64List(value=encode_sequence(label))), } example = Example(features=Features(feature=schema)) return example.SerializeToString() @@ -48,7 +52,7 @@ def read_fastq(data_file, label_file) -> List[bytes]: example = generate_example( sequence=str(element.seq), label=str(label.seq), - weight_matrix=motifs.pwm, + base_counts=motifs.counts, ) examples.append(example) return examples @@ -73,19 +77,23 @@ def create_dataset( test.write(element) -def process_input(byte_string) -> Example: +def process_input(byte_string) -> Tuple[Tensor, Tensor]: """ Parse a byte-string into an Example object """ schema = { - "sequence": FixedLenFeature(shape=[], dtype=int64), - "label": FixedLenFeature(shape=[], dtype=int64), - "A_counts": FixedLenFeature(shape=[], dtype=float32), - "C_counts": FixedLenFeature(shape=[], dtype=float32), - "G_counts": FixedLenFeature(shape=[], dtype=float32), - "T_counts": FixedLenFeature(shape=[], dtype=float32), + "A_counts": FixedLenFeature(shape=[1], dtype=int64), + "C_counts": FixedLenFeature(shape=[1], dtype=int64), + "G_counts": FixedLenFeature(shape=[1], dtype=int64), + "T_counts": FixedLenFeature(shape=[1], dtype=int64), + "sequence": VarLenFeature(dtype=int64), + "label": VarLenFeature(dtype=int64), } - return parse_single_example(byte_string, features=schema) + parsed_features = parse_single_example(byte_string, features=schema) + parsed_features["sequence"] = to_dense(parsed_features["sequence"]) + parsed_features["label"] = to_dense(parsed_features["label"]) + features = list(parsed_features.values())[:-1] + return stack(features, axis=-1), parsed_features["label"] def read_dataset(filepath) -> TFRecordDataset: