Parallelize dataset transformations

This commit is contained in:
coolneng 2021-06-24 19:30:46 +02:00
parent b2f20f2070
commit e9582d0883
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed file with 2 additions and 1 deletion

View File

@@ -5,6 +5,7 @@ from Bio.SeqIO import parse
from numpy.random import random
from tensorflow import Tensor, int64
from tensorflow.data import TFRecordDataset
from tensorflow.data import AUTOTUNE, TFRecordDataset
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
from tensorflow.sparse import to_dense
from tensorflow.train import Example, Feature, Features, Int64List
@@ -104,7 +105,7 @@ def read_dataset(filepath) -> TFRecordDataset:
Read TFRecords files and generate a dataset
"""
data_input = TFRecordDataset(filenames=filepath)
dataset = data_input.map(map_func=process_input)
dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
return batched_dataset