Cast the parsed features to int32
This commit is contained in:
parent
d2e5fd0fa3
commit
379303b440
|
@ -3,7 +3,7 @@ from typing import List, Tuple
|
||||||
from Bio.motifs import create
|
from Bio.motifs import create
|
||||||
from Bio.SeqIO import parse
|
from Bio.SeqIO import parse
|
||||||
from numpy.random import random
|
from numpy.random import random
|
||||||
from tensorflow import Tensor, int64, stack
|
from tensorflow import Tensor, int64, stack, cast, int32
|
||||||
from tensorflow.sparse import to_dense
|
from tensorflow.sparse import to_dense
|
||||||
from tensorflow.data import TFRecordDataset
|
from tensorflow.data import TFRecordDataset
|
||||||
from tensorflow.io import (
|
from tensorflow.io import (
|
||||||
|
@ -77,6 +77,22 @@ def create_dataset(
|
||||||
test.write(element)
|
test.write(element)
|
||||||
|
|
||||||
|
|
||||||
|
def transform_features(parsed_features) -> List[Tensor]:
    """
    Cast the parsed features of an Example to int32 and collect them in a list

    The mapping is updated in place: every "<base>_counts" entry (one per
    element of BASES), "sequence" and "label" are cast to int32, after which
    "sequence" and "label" are converted from sparse to dense tensors.

    Returns every value of the mapping except the last one, in the mapping's
    iteration order.
    NOTE(review): this presumably relies on "label" being the final key so
    that it is the entry left out -- confirm against the parsing schema.
    """
    int_keys = [f"{base}_counts" for base in BASES] + ["sequence", "label"]
    for key in int_keys:
        parsed_features[key] = cast(parsed_features[key], int32)

    # Densify only after casting, matching the original order of operations.
    for key in ("sequence", "label"):
        parsed_features[key] = to_dense(parsed_features[key])

    return list(parsed_features.values())[:-1]
|
||||||
|
|
||||||
|
|
||||||
def process_input(byte_string) -> Tuple[Tensor, Tensor]:
|
def process_input(byte_string) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""
|
||||||
Parse a byte-string into an Example object
|
Parse a byte-string into an Example object
|
||||||
|
@ -90,9 +106,7 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
|
||||||
"label": VarLenFeature(dtype=int64),
|
"label": VarLenFeature(dtype=int64),
|
||||||
}
|
}
|
||||||
parsed_features = parse_single_example(byte_string, features=schema)
|
parsed_features = parse_single_example(byte_string, features=schema)
|
||||||
parsed_features["sequence"] = to_dense(parsed_features["sequence"])
|
features = transform_features(parsed_features)
|
||||||
parsed_features["label"] = to_dense(parsed_features["label"])
|
|
||||||
features = list(parsed_features.values())[:-1]
|
|
||||||
return stack(features, axis=-1), parsed_features["label"]
|
return stack(features, axis=-1), parsed_features["label"]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue