Remove dense Tensor transformation
This commit is contained in:
parent
1e433c123f
commit
0912600fdc
|
@ -1,11 +1,10 @@
|
||||||
from typing import Dict, List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from Bio.SeqIO import parse
|
from Bio.SeqIO import parse
|
||||||
from numpy.random import random
|
from numpy.random import random
|
||||||
from tensorflow import Tensor, int64
|
from tensorflow import Tensor, int64
|
||||||
from tensorflow.data import TFRecordDataset
|
from tensorflow.data import TFRecordDataset
|
||||||
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
|
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
|
||||||
from tensorflow.sparse import to_dense
|
|
||||||
from tensorflow.train import Example, Feature, Features, Int64List
|
from tensorflow.train import Example, Feature, Features, Int64List
|
||||||
|
|
||||||
from constants import *
|
from constants import *
|
||||||
|
@ -38,44 +37,28 @@ def read_fastq(data_file, label_file) -> List[bytes]:
|
||||||
examples = []
|
examples = []
|
||||||
with open(data_file) as data, open(label_file) as labels:
|
with open(data_file) as data, open(label_file) as labels:
|
||||||
for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
|
for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
|
||||||
example = generate_example(
|
example = generate_example(sequence=str(element.seq), label=str(label.seq))
|
||||||
sequence=str(element.seq),
|
|
||||||
label=str(label.seq),
|
|
||||||
)
|
|
||||||
examples.append(example)
|
examples.append(example)
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
def create_dataset(
|
def create_dataset(data_file, label_file, dataset_split=[0.8, 0.1, 0.1]) -> None:
|
||||||
data_file, label_file, train_eval_test_split=[0.8, 0.1, 0.1]
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Create a training, evaluation and test dataset with a 80/10/30 split respectively
|
Create a training, evaluation and test dataset with a 80/10/10 split respectively
|
||||||
"""
|
"""
|
||||||
data = read_fastq(data_file, label_file)
|
data = read_fastq(data_file, label_file)
|
||||||
with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
|
with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
|
||||||
TEST_DATASET
|
TEST_DATASET
|
||||||
) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
|
) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
|
||||||
for element in data:
|
for element in data:
|
||||||
if random() < train_eval_test_split[0]:
|
if random() < dataset_split[0]:
|
||||||
training.write(element)
|
training.write(element)
|
||||||
elif random() < train_eval_test_split[0] + train_eval_test_split[1]:
|
elif random() < dataset_split[0] + dataset_split[1]:
|
||||||
evaluation.write(element)
|
evaluation.write(element)
|
||||||
else:
|
else:
|
||||||
test.write(element)
|
test.write(element)
|
||||||
|
|
||||||
|
|
||||||
def transform_features(parsed_features) -> Dict[str, Tensor]:
|
|
||||||
"""
|
|
||||||
Transform the parsed features of an Example into a list of dense Tensors
|
|
||||||
"""
|
|
||||||
features = {}
|
|
||||||
sparse_features = ["sequence", "label"]
|
|
||||||
for element in sparse_features:
|
|
||||||
features[element] = to_dense(parsed_features[element])
|
|
||||||
return features
|
|
||||||
|
|
||||||
|
|
||||||
def process_input(byte_string) -> Tuple[Tensor, Tensor]:
|
def process_input(byte_string) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""
|
||||||
Parse a byte-string into an Example object
|
Parse a byte-string into an Example object
|
||||||
|
@ -84,8 +67,7 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
|
||||||
"sequence": VarLenFeature(dtype=int64),
|
"sequence": VarLenFeature(dtype=int64),
|
||||||
"label": VarLenFeature(dtype=int64),
|
"label": VarLenFeature(dtype=int64),
|
||||||
}
|
}
|
||||||
parsed_features = parse_single_example(byte_string, features=schema)
|
features = parse_single_example(byte_string, features=schema)
|
||||||
features = transform_features(parsed_features)
|
|
||||||
return features["sequence"], features["label"]
|
return features["sequence"], features["label"]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue