Compare commits

...

1 Commits

Author SHA1 Message Date
coolneng 0912600fdc
Remove dense Tensor transformation 2021-06-23 18:28:09 +02:00
1 changed files with 7 additions and 25 deletions

View File

@ -1,11 +1,10 @@
from typing import Dict, List, Tuple
from typing import List, Tuple
from Bio.SeqIO import parse
from numpy.random import random
from tensorflow import Tensor, int64
from tensorflow.data import TFRecordDataset
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
from tensorflow.sparse import to_dense
from tensorflow.train import Example, Feature, Features, Int64List
from constants import *
@ -38,44 +37,28 @@ def read_fastq(data_file, label_file) -> List[bytes]:
examples = []
with open(data_file) as data, open(label_file) as labels:
for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
example = generate_example(
sequence=str(element.seq),
label=str(label.seq),
)
example = generate_example(sequence=str(element.seq), label=str(label.seq))
examples.append(example)
return examples
def create_dataset(
data_file, label_file, train_eval_test_split=[0.8, 0.1, 0.1]
) -> None:
def create_dataset(data_file, label_file, dataset_split=[0.8, 0.1, 0.1]) -> None:
"""
Create a training, evaluation and test dataset with a 80/10/30 split respectively
Create a training, evaluation and test dataset with a 80/10/10 split respectively
"""
data = read_fastq(data_file, label_file)
with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
TEST_DATASET
) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
for element in data:
if random() < train_eval_test_split[0]:
if random() < dataset_split[0]:
training.write(element)
elif random() < train_eval_test_split[0] + train_eval_test_split[1]:
elif random() < dataset_split[0] + dataset_split[1]:
evaluation.write(element)
else:
test.write(element)
def transform_features(parsed_features) -> Dict[str, Tensor]:
"""
Transform the parsed features of an Example into a list of dense Tensors
"""
features = {}
sparse_features = ["sequence", "label"]
for element in sparse_features:
features[element] = to_dense(parsed_features[element])
return features
def process_input(byte_string) -> Tuple[Tensor, Tensor]:
"""
Parse a byte-string into an Example object
@ -84,8 +67,7 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
"sequence": VarLenFeature(dtype=int64),
"label": VarLenFeature(dtype=int64),
}
parsed_features = parse_single_example(byte_string, features=schema)
features = transform_features(parsed_features)
features = parse_single_example(byte_string, features=schema)
return features["sequence"], features["label"]