Parallelize dataset transformations
This commit is contained in:
parent
b2f20f2070
commit
e9582d0883
|
@@ -5,6 +5,7 @@ from Bio.SeqIO import parse
|
||||||
from numpy.random import random
|
from numpy.random import random
|
||||||
from tensorflow import Tensor, int64
|
from tensorflow import Tensor, int64
|
||||||
from tensorflow.data import TFRecordDataset
|
from tensorflow.data import TFRecordDataset
|
||||||
|
from tensorflow.data import AUTOTUNE, TFRecordDataset
|
||||||
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
|
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
|
||||||
from tensorflow.sparse import to_dense
|
from tensorflow.sparse import to_dense
|
||||||
from tensorflow.train import Example, Feature, Features, Int64List
|
from tensorflow.train import Example, Feature, Features, Int64List
|
||||||
|
@@ -104,7 +105,7 @@ def read_dataset(filepath) -> TFRecordDataset:
|
||||||
Read TFRecords files and generate a dataset
|
Read TFRecords files and generate a dataset
|
||||||
"""
|
"""
|
||||||
data_input = TFRecordDataset(filenames=filepath)
|
data_input = TFRecordDataset(filenames=filepath)
|
||||||
dataset = data_input.map(map_func=process_input)
|
dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
|
||||||
shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
|
shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
|
||||||
batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
|
batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
|
||||||
return batched_dataset
|
return batched_dataset
|
||||||
|
|
Loading…
Reference in New Issue