diff --git a/src/preprocessing.py b/src/preprocessing.py index e2995b4..0b0c52a 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -3,8 +3,7 @@ from typing import Dict, List, Tuple from Bio.pairwise2 import align from Bio.SeqIO import parse from numpy.random import random -from tensorflow import Tensor, int64 -from tensorflow.data import TFRecordDataset +from tensorflow import Tensor, int64, one_hot from tensorflow.data import AUTOTUNE, TFRecordDataset from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example from tensorflow.sparse import to_dense @@ -78,12 +77,13 @@ def create_dataset(data_file, label_file, dataset_split=[0.8, 0.1, 0.1]) -> None def transform_features(parsed_features) -> Dict[str, Tensor]: """ - Transform the parsed features of an Example into a list of dense Tensors + Transform the parsed features of an Example into a list of dense one hot encoded Tensors """ features = {} sparse_features = ["sequence", "label"] for element in sparse_features: features[element] = to_dense(parsed_features[element]) + features[element] = one_hot(features[element], depth=len(BASES)) return features