Compare commits

...

1 Commit

Author SHA1 Message Date
coolneng 0912600fdc
Remove dense Tensor transformation 2021-06-23 18:28:09 +02:00
1 changed file with 7 additions and 25 deletions

View File

@ -1,11 +1,10 @@
from typing import Dict, List, Tuple from typing import List, Tuple
from Bio.SeqIO import parse from Bio.SeqIO import parse
from numpy.random import random from numpy.random import random
from tensorflow import Tensor, int64 from tensorflow import Tensor, int64
from tensorflow.data import TFRecordDataset from tensorflow.data import TFRecordDataset
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
from tensorflow.sparse import to_dense
from tensorflow.train import Example, Feature, Features, Int64List from tensorflow.train import Example, Feature, Features, Int64List
from constants import * from constants import *
@ -38,44 +37,28 @@ def read_fastq(data_file, label_file) -> List[bytes]:
examples = [] examples = []
with open(data_file) as data, open(label_file) as labels: with open(data_file) as data, open(label_file) as labels:
for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")): for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
example = generate_example( example = generate_example(sequence=str(element.seq), label=str(label.seq))
sequence=str(element.seq),
label=str(label.seq),
)
examples.append(example) examples.append(example)
return examples return examples
def create_dataset( def create_dataset(data_file, label_file, dataset_split=[0.8, 0.1, 0.1]) -> None:
data_file, label_file, train_eval_test_split=[0.8, 0.1, 0.1]
) -> None:
""" """
Create a training, evaluation and test dataset with a 80/10/30 split respectively Create a training, evaluation and test dataset with a 80/10/10 split respectively
""" """
data = read_fastq(data_file, label_file) data = read_fastq(data_file, label_file)
with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter( with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
TEST_DATASET TEST_DATASET
) as test, TFRecordWriter(EVAL_DATASET) as evaluation: ) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
for element in data: for element in data:
if random() < train_eval_test_split[0]: if random() < dataset_split[0]:
training.write(element) training.write(element)
elif random() < train_eval_test_split[0] + train_eval_test_split[1]: elif random() < dataset_split[0] + dataset_split[1]:
evaluation.write(element) evaluation.write(element)
else: else:
test.write(element) test.write(element)
def transform_features(parsed_features) -> Dict[str, Tensor]:
"""
Transform the parsed features of an Example into a list of dense Tensors
"""
features = {}
sparse_features = ["sequence", "label"]
for element in sparse_features:
features[element] = to_dense(parsed_features[element])
return features
def process_input(byte_string) -> Tuple[Tensor, Tensor]: def process_input(byte_string) -> Tuple[Tensor, Tensor]:
""" """
Parse a byte-string into an Example object Parse a byte-string into an Example object
@ -84,8 +67,7 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
"sequence": VarLenFeature(dtype=int64), "sequence": VarLenFeature(dtype=int64),
"label": VarLenFeature(dtype=int64), "label": VarLenFeature(dtype=int64),
} }
parsed_features = parse_single_example(byte_string, features=schema) features = parse_single_example(byte_string, features=schema)
features = transform_features(parsed_features)
return features["sequence"], features["label"] return features["sequence"], features["label"]