diff --git a/src/preprocessing.py b/src/preprocessing.py
index df705a6..04cabbe 100644
--- a/src/preprocessing.py
+++ b/src/preprocessing.py
@@ -1,11 +1,10 @@
-from typing import Dict, List, Tuple
+from typing import List, Tuple
 from Bio.SeqIO import parse
 from numpy.random import random
 from tensorflow import Tensor, int64
 from tensorflow.data import TFRecordDataset
 from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
-from tensorflow.sparse import to_dense
 from tensorflow.train import Example, Feature, Features, Int64List
 
 from constants import *
 
@@ -38,44 +37,28 @@ def read_fastq(data_file, label_file) -> List[bytes]:
     examples = []
     with open(data_file) as data, open(label_file) as labels:
         for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
-            example = generate_example(
-                sequence=str(element.seq),
-                label=str(label.seq),
-            )
+            example = generate_example(sequence=str(element.seq), label=str(label.seq))
             examples.append(example)
     return examples
 
 
-def create_dataset(
-    data_file, label_file, train_eval_test_split=[0.8, 0.1, 0.1]
-) -> None:
+def create_dataset(data_file, label_file, dataset_split=[0.8, 0.1, 0.1]) -> None:
     """
-    Create a training, evaluation and test dataset with a 80/10/30 split respectively
+    Create training, evaluation and test datasets with an 80/10/10 split, respectively
     """
     data = read_fastq(data_file, label_file)
     with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
         TEST_DATASET
     ) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
         for element in data:
-            if random() < train_eval_test_split[0]:
+            if random() < dataset_split[0]:
                 training.write(element)
-            elif random() < train_eval_test_split[0] + train_eval_test_split[1]:
+            elif random() < dataset_split[0] + dataset_split[1]:
                 evaluation.write(element)
             else:
                 test.write(element)
 
 
-def transform_features(parsed_features) -> Dict[str, Tensor]:
-    """
-    Transform the parsed features of an Example into a list of dense Tensors
-    """
-    features = {}
-    sparse_features = ["sequence", "label"]
-    for element in sparse_features:
-        features[element] = to_dense(parsed_features[element])
-    return features
-
-
 def process_input(byte_string) -> Tuple[Tensor, Tensor]:
     """
     Parse a byte-string into an Example object
@@ -84,8 +67,7 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
         "sequence": VarLenFeature(dtype=int64),
         "label": VarLenFeature(dtype=int64),
     }
-    parsed_features = parse_single_example(byte_string, features=schema)
-    features = transform_features(parsed_features)
+    features = parse_single_example(byte_string, features=schema)
     return features["sequence"], features["label"]
 
 
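Because `VarLenFeature` entries are parsed into `tf.sparse.SparseTensor`s, `process_input` now hands sparse tensors straight to its callers after this change. Below is a minimal consumer sketch, not part of the diff, assuming `TRAIN_DATASET` (from `constants`) is the TFRecord path written by `create_dataset`, that this file is importable as `preprocessing`, and that downstream code still wants dense, zero-padded tensors; `BATCH_SIZE` is a hypothetical placeholder.

```python
import tensorflow as tf

from constants import TRAIN_DATASET  # assumed: path of the training TFRecord
from preprocessing import process_input  # assumed module name for this file

BATCH_SIZE = 32  # hypothetical; match whatever the training loop expects


def load_training_dataset() -> tf.data.Dataset:
    """Read the training TFRecord and densify the sparse sequence/label pairs."""
    dataset = tf.data.TFRecordDataset(TRAIN_DATASET)
    # process_input yields (sequence, label) pairs of SparseTensors.
    dataset = dataset.map(process_input, num_parallel_calls=tf.data.AUTOTUNE)
    # Zero-pad each element to a dense tensor, mirroring the to_dense call
    # that transform_features used to perform per example.
    dataset = dataset.map(
        lambda sequence, label: (tf.sparse.to_dense(sequence), tf.sparse.to_dense(label)),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    # padded_batch handles variable-length reads by padding to the batch max.
    return dataset.padded_batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
```

Keeping the conversion in the input pipeline rather than inside `process_input` leaves the stored Examples untouched and lets models that can consume SparseTensors skip the extra map.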