diff --git a/src/preprocessing.py b/src/preprocessing.py index e2995b4..e1087a8 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -3,8 +3,7 @@ from typing import Dict, List, Tuple from Bio.pairwise2 import align from Bio.SeqIO import parse from numpy.random import random -from tensorflow import Tensor, int64 -from tensorflow.data import TFRecordDataset +from tensorflow import Tensor, int64, one_hot from tensorflow.data import AUTOTUNE, TFRecordDataset from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example from tensorflow.sparse import to_dense @@ -84,6 +83,7 @@ def transform_features(parsed_features) -> Dict[str, Tensor]: sparse_features = ["sequence", "label"] for element in sparse_features: features[element] = to_dense(parsed_features[element]) + features[element] = one_hot(features[element], depth=len(BASES)) return features