Parallelize dataset transformations

coolneng 2021-06-24 19:30:46 +02:00
parent b2f20f2070
commit e9582d0883
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed file with 2 additions and 1 deletion


@@ -5,6 +5,7 @@ from Bio.SeqIO import parse
 from numpy.random import random
 from tensorflow import Tensor, int64
 from tensorflow.data import TFRecordDataset
+from tensorflow.data import AUTOTUNE, TFRecordDataset
 from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
 from tensorflow.sparse import to_dense
 from tensorflow.train import Example, Feature, Features, Int64List
@@ -104,7 +105,7 @@ def read_dataset(filepath) -> TFRecordDataset:
     Read TFRecords files and generate a dataset
     """
     data_input = TFRecordDataset(filenames=filepath)
-    dataset = data_input.map(map_func=process_input)
+    dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
     shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
     batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
     return batched_dataset
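
The change passes num_parallel_calls=AUTOTUNE to Dataset.map, so the tf.data runtime decides how many records to transform concurrently instead of applying process_input one element at a time. A minimal sketch of the same pattern is below; parse_example, the feature spec, the filename, and the batch size are illustrative stand-ins, not the repository's own code.

from tensorflow import data, int64, io, sparse

def parse_example(record):
    # Illustrative feature spec: one variable-length int64 sequence per record
    features = {"sequence": io.VarLenFeature(int64)}
    example = io.parse_single_example(record, features)
    return sparse.to_dense(example["sequence"])

dataset = (
    data.TFRecordDataset(filenames=["example.tfrecord"])
    # AUTOTUNE lets tf.data tune the degree of parallelism for the map step
    .map(map_func=parse_example, num_parallel_calls=data.AUTOTUNE)
    .shuffle(buffer_size=10000, seed=42)
    .batch(batch_size=32)
)

Without num_parallel_calls, map runs the transformation sequentially; with AUTOTUNE, tf.data adjusts the number of parallel calls to the available CPU, which typically keeps record parsing from becoming the training bottleneck.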