Remove dense Tensor transformation

coolneng 2021-06-23 18:27:19 +02:00
parent 1e433c123f
commit 0912600fdc
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed file with 7 additions and 25 deletions


@@ -1,11 +1,10 @@
-from typing import Dict, List, Tuple
+from typing import List, Tuple
 
 from Bio.SeqIO import parse
 from numpy.random import random
 from tensorflow import Tensor, int64
 from tensorflow.data import TFRecordDataset
 from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
-from tensorflow.sparse import to_dense
 from tensorflow.train import Example, Feature, Features, Int64List
 
 from constants import *
@@ -38,44 +37,28 @@ def read_fastq(data_file, label_file) -> List[bytes]:
     examples = []
     with open(data_file) as data, open(label_file) as labels:
         for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
-            example = generate_example(
-                sequence=str(element.seq),
-                label=str(label.seq),
-            )
+            example = generate_example(sequence=str(element.seq), label=str(label.seq))
             examples.append(example)
     return examples
 
 
-def create_dataset(
-    data_file, label_file, train_eval_test_split=[0.8, 0.1, 0.1]
-) -> None:
+def create_dataset(data_file, label_file, dataset_split=[0.8, 0.1, 0.1]) -> None:
     """
-    Create a training, evaluation and test dataset with a 80/10/30 split respectively
+    Create a training, evaluation and test dataset with a 80/10/10 split respectively
     """
     data = read_fastq(data_file, label_file)
     with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
         TEST_DATASET
     ) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
         for element in data:
-            if random() < train_eval_test_split[0]:
+            if random() < dataset_split[0]:
                 training.write(element)
-            elif random() < train_eval_test_split[0] + train_eval_test_split[1]:
+            elif random() < dataset_split[0] + dataset_split[1]:
                 evaluation.write(element)
             else:
                 test.write(element)
-
-
-def transform_features(parsed_features) -> Dict[str, Tensor]:
-    """
-    Transform the parsed features of an Example into a list of dense Tensors
-    """
-    features = {}
-    sparse_features = ["sequence", "label"]
-    for element in sparse_features:
-        features[element] = to_dense(parsed_features[element])
-    return features
 
 
 def process_input(byte_string) -> Tuple[Tensor, Tensor]:
     """
     Parse a byte-string into an Example object
@@ -84,8 +67,7 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
         "sequence": VarLenFeature(dtype=int64),
         "label": VarLenFeature(dtype=int64),
     }
-    parsed_features = parse_single_example(byte_string, features=schema)
-    features = transform_features(parsed_features)
+    features = parse_single_example(byte_string, features=schema)
     return features["sequence"], features["label"]
 
 
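
Not part of the commit, just a sketch of the downstream effect: with transform_features removed, parse_single_example leaves the VarLenFeature entries as tf.sparse.SparseTensor values, so any consumer that needs dense Tensors now has to densify them at the point of use. The snippet below assumes this module's process_input and the TRAIN_DATASET constant are importable; load_dataset and dense_dataset are hypothetical names used only for illustration.

import tensorflow as tf

def load_dataset(path):
    # TFRecordDataset yields serialized Example byte-strings; process_input
    # parses each one into a (sequence, label) pair, now kept sparse.
    dataset = tf.data.TFRecordDataset(path)
    return dataset.map(process_input)

# Hypothetical consumer: densify per element, replacing the removed
# transform_features step where dense input is actually required.
dense_dataset = load_dataset(TRAIN_DATASET).map(
    lambda sequence, label: (tf.sparse.to_dense(sequence), tf.sparse.to_dense(label))
)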