diff --git a/src/preprocessing.py b/src/preprocessing.py index b2beb41..a2673f8 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -3,7 +3,7 @@ from typing import List, Tuple from Bio.motifs import create from Bio.SeqIO import parse from numpy.random import random -from tensorflow import Tensor, int64, stack +from tensorflow import Tensor, int64, stack, cast, int32 from tensorflow.sparse import to_dense from tensorflow.data import TFRecordDataset from tensorflow.io import ( @@ -77,6 +77,22 @@ def create_dataset( test.write(element) +def transform_features(parsed_features) -> List[Tensor]: + """ + Cast and transform the parsed features of an Example into a list of Tensors + """ + for base in BASES: + parsed_features[f"{base}_counts"] = cast( + parsed_features[f"{base}_counts"], int32 + ) + parsed_features["sequence"] = cast(parsed_features["sequence"], int32) + parsed_features["label"] = cast(parsed_features["label"], int32) + parsed_features["sequence"] = to_dense(parsed_features["sequence"]) + parsed_features["label"] = to_dense(parsed_features["label"]) + features = list(parsed_features.values())[:-1] + return features + + def process_input(byte_string) -> Tuple[Tensor, Tensor]: """ Parse a byte-string into an Example object @@ -90,9 +106,7 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]: "label": VarLenFeature(dtype=int64), } parsed_features = parse_single_example(byte_string, features=schema) - parsed_features["sequence"] = to_dense(parsed_features["sequence"]) - parsed_features["label"] = to_dense(parsed_features["label"]) - features = list(parsed_features.values())[:-1] + features = transform_features(parsed_features) return stack(features, axis=-1), parsed_features["label"]