Cast the parsed features to int32
This commit is contained in:
parent
d2e5fd0fa3
commit
379303b440
|
@ -3,7 +3,7 @@ from typing import List, Tuple
|
|||
from Bio.motifs import create
|
||||
from Bio.SeqIO import parse
|
||||
from numpy.random import random
|
||||
from tensorflow import Tensor, int64, stack
|
||||
from tensorflow import Tensor, int64, stack, cast, int32
|
||||
from tensorflow.sparse import to_dense
|
||||
from tensorflow.data import TFRecordDataset
|
||||
from tensorflow.io import (
|
||||
|
@ -77,6 +77,22 @@ def create_dataset(
|
|||
test.write(element)
|
||||
|
||||
|
||||
def transform_features(parsed_features) -> List[Tensor]:
|
||||
"""
|
||||
Cast and transform the parsed features of an Example into a list of Tensors
|
||||
"""
|
||||
for base in BASES:
|
||||
parsed_features[f"{base}_counts"] = cast(
|
||||
parsed_features[f"{base}_counts"], int32
|
||||
)
|
||||
parsed_features["sequence"] = cast(parsed_features["sequence"], int32)
|
||||
parsed_features["label"] = cast(parsed_features["label"], int32)
|
||||
parsed_features["sequence"] = to_dense(parsed_features["sequence"])
|
||||
parsed_features["label"] = to_dense(parsed_features["label"])
|
||||
features = list(parsed_features.values())[:-1]
|
||||
return features
|
||||
|
||||
|
||||
def process_input(byte_string) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Parse a byte-string into an Example object
|
||||
|
@ -90,9 +106,7 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
|
|||
"label": VarLenFeature(dtype=int64),
|
||||
}
|
||||
parsed_features = parse_single_example(byte_string, features=schema)
|
||||
parsed_features["sequence"] = to_dense(parsed_features["sequence"])
|
||||
parsed_features["label"] = to_dense(parsed_features["label"])
|
||||
features = list(parsed_features.values())[:-1]
|
||||
features = transform_features(parsed_features)
|
||||
return stack(features, axis=-1), parsed_features["label"]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue