Compare commits

...

1 Commits

Author SHA1 Message Date
coolneng e9582d0883
Parallelize dataset transformations 2021-06-24 19:54:19 +02:00
1 changed files with 2 additions and 1 deletions

View File

@ -5,6 +5,7 @@ from Bio.SeqIO import parse
from numpy.random import random from numpy.random import random
from tensorflow import Tensor, int64 from tensorflow import Tensor, int64
from tensorflow.data import TFRecordDataset from tensorflow.data import TFRecordDataset
from tensorflow.data import AUTOTUNE, TFRecordDataset
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
from tensorflow.sparse import to_dense from tensorflow.sparse import to_dense
from tensorflow.train import Example, Feature, Features, Int64List from tensorflow.train import Example, Feature, Features, Int64List
@ -104,7 +105,7 @@ def read_dataset(filepath) -> TFRecordDataset:
Read TFRecords files and generate a dataset Read TFRecords files and generate a dataset
""" """
data_input = TFRecordDataset(filenames=filepath) data_input = TFRecordDataset(filenames=filepath)
dataset = data_input.map(map_func=process_input) dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42) shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS) batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
return batched_dataset return batched_dataset