Refactor the casting function using a loop

This commit is contained in:
coolneng 2021-06-15 00:22:55 +02:00
parent 379303b440
commit 7029b64906
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed files with 4 additions and 4 deletions

View File

@ -81,14 +81,14 @@ def transform_features(parsed_features) -> List[Tensor]:
"""
Cast and transform the parsed features of an Example into a list of Tensors
"""
sparse_features = ["sequence", "label"]
for feature in sparse_features:
parsed_features[feature] = cast(parsed_features[feature], int32)
parsed_features[feature] = to_dense(parsed_features[feature])
for base in BASES:
parsed_features[f"{base}_counts"] = cast(
parsed_features[f"{base}_counts"], int32
)
parsed_features["sequence"] = cast(parsed_features["sequence"], int32)
parsed_features["label"] = cast(parsed_features["label"], int32)
parsed_features["sequence"] = to_dense(parsed_features["sequence"])
parsed_features["label"] = to_dense(parsed_features["label"])
features = list(parsed_features.values())[:-1]
return features