diff --git a/src/preprocessing.py b/src/preprocessing.py index 311781a..e2995b4 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -5,6 +5,7 @@ from Bio.SeqIO import parse from numpy.random import random from tensorflow import Tensor, int64 from tensorflow.data import TFRecordDataset +from tensorflow.data import AUTOTUNE, TFRecordDataset from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example from tensorflow.sparse import to_dense from tensorflow.train import Example, Feature, Features, Int64List @@ -104,7 +105,7 @@ def read_dataset(filepath) -> TFRecordDataset: Read TFRecords files and generate a dataset """ data_input = TFRecordDataset(filenames=filepath) - dataset = data_input.map(map_func=process_input) + dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE) shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42) batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS) return batched_dataset