From b2f20f2070d0f4f1e88390e46fe7a5049126133f Mon Sep 17 00:00:00 2001 From: coolneng Date: Thu, 24 Jun 2021 17:10:07 +0200 Subject: [PATCH] Revert "Remove dense Tensor transformation" This reverts commit 0912600fdc9636c7b32557a00d793c14d0c0278a. --- src/preprocessing.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/preprocessing.py b/src/preprocessing.py index d4a8734..311781a 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import Dict, List, Tuple from Bio.pairwise2 import align from Bio.SeqIO import parse @@ -6,6 +6,7 @@ from numpy.random import random from tensorflow import Tensor, int64 from tensorflow.data import TFRecordDataset from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example +from tensorflow.sparse import to_dense from tensorflow.train import Example, Feature, Features, Int64List from constants import * @@ -49,7 +50,10 @@ def read_fastq(data_file, label_file) -> List[bytes]: examples = [] with open(data_file) as data, open(label_file) as labels: for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")): - example = generate_example(sequence=str(element.seq), label=str(label.seq)) + example = generate_example( + sequence=str(element.seq), + label=str(label.seq), + ) examples.append(example) return examples @@ -71,6 +75,17 @@ def create_dataset(data_file, label_file, dataset_split=[0.8, 0.1, 0.1]) -> None test.write(element) +def transform_features(parsed_features) -> Dict[str, Tensor]: + """ + Transform the parsed features of an Example into a list of dense Tensors + """ + features = {} + sparse_features = ["sequence", "label"] + for element in sparse_features: + features[element] = to_dense(parsed_features[element]) + return features + + def process_input(byte_string) -> Tuple[Tensor, Tensor]: """ Parse a byte-string into an Example object @@ -79,7 +94,8 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]: "sequence": VarLenFeature(dtype=int64), "label": VarLenFeature(dtype=int64), } - features = parse_single_example(byte_string, features=schema) + parsed_features = parse_single_example(byte_string, features=schema) + features = transform_features(parsed_features) return features["sequence"], features["label"]