Cast the parsed features to int32

This commit is contained in:
coolneng 2021-06-15 00:18:38 +02:00
parent d2e5fd0fa3
commit 379303b440
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed file with 18 additions and 4 deletions

View File

@ -3,7 +3,7 @@ from typing import List, Tuple
from Bio.motifs import create
from Bio.SeqIO import parse
from numpy.random import random
from tensorflow import Tensor, int64, stack
from tensorflow import Tensor, int64, stack, cast, int32
from tensorflow.sparse import to_dense
from tensorflow.data import TFRecordDataset
from tensorflow.io import (
@ -77,6 +77,22 @@ def create_dataset(
test.write(element)
def transform_features(parsed_features) -> List[Tensor]:
    """
    Cast and transform the parsed features of an Example into a list of Tensors

    Every "<base>_counts" entry (one per element of BASES), together with the
    "sequence" and "label" entries, is cast to int32. "sequence" and "label"
    are then converted with to_dense. The dictionary is updated in place, so
    the caller can still read the converted "label" from it afterwards.

    Returns the values of every entry except "label", in the dictionary's
    iteration order.
    """
    count_keys = [f"{base}_counts" for base in BASES]
    for key in (*count_keys, "sequence", "label"):
        parsed_features[key] = cast(parsed_features[key], int32)
    for key in ("sequence", "label"):
        parsed_features[key] = to_dense(parsed_features[key])
    # Exclude the label by key rather than by position: dropping "the last
    # value" leaks the label into the features (and discards a real feature)
    # whenever "label" is not the final key of the parsed dictionary.
    return [
        tensor for key, tensor in parsed_features.items() if key != "label"
    ]
def process_input(byte_string) -> Tuple[Tensor, Tensor]:
"""
Parse a byte-string into an Example object
@ -90,9 +106,7 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
"label": VarLenFeature(dtype=int64),
}
parsed_features = parse_single_example(byte_string, features=schema)
parsed_features["sequence"] = to_dense(parsed_features["sequence"])
parsed_features["label"] = to_dense(parsed_features["label"])
features = list(parsed_features.values())[:-1]
features = transform_features(parsed_features)
return stack(features, axis=-1), parsed_features["label"]