Pad and mask the sequences in each batch

coolneng 2021-07-05 19:55:31 +02:00
parent 70363a82a0
commit 1a1262b0b1
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
3 changed files with 11 additions and 3 deletions


@@ -11,6 +11,7 @@ class Hyperparameters:
         learning_rate=0.004,
         l2_rate=0.001,
         log_directory="logs",
+        max_length=80,
     ):
         self.data_file = data_file
         self.label_file = label_file
@@ -22,3 +23,4 @@ class Hyperparameters:
         self.learning_rate = learning_rate
         self.l2_rate = l2_rate
         self.log_directory = log_directory
+        self.max_length = max_length
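For illustration, a minimal sketch of how the new hyperparameter would be set; the constructor's required arguments and the file paths are assumptions, only max_length=80 comes from this change.

# Hypothetical instantiation; the data_file/label_file paths are placeholders.
hyperparams = Hyperparameters(
    data_file="data/sequences.tfrecord",
    label_file="data/labels.tfrecord",
    max_length=80,  # every batch is padded to this length downstream
)
print(hyperparams.max_length)  # 80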


@@ -17,7 +17,8 @@ def build_model(hyperparams) -> Model:
     """
     model = Sequential(
         [
-            Input(shape=(None, len(BASES))),
+            Input(shape=(None, hyperparams.max_length, len(BASES))),
+            Masking(mask_value=-1),
            Conv1D(
                filters=16,
                kernel_size=5,
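As a sanity check on the masking side, here is a small sketch of how Keras' Masking layer treats timesteps whose features all equal the sentinel; the toy one-hot batch is made up, but the mask_value matches the -1 used above and the -1.0 padding value in the input pipeline.

import tensorflow as tf

# Two one-hot sequences over 4 bases, already padded to length 5
# with the same sentinel (-1) that padded_batch inserts.
batch = tf.constant(
    [
        [[1, 0, 0, 0], [0, 1, 0, 0], [-1, -1, -1, -1], [-1, -1, -1, -1], [-1, -1, -1, -1]],
        [[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [-1, -1, -1, -1], [-1, -1, -1, -1]],
    ],
    dtype=tf.float32,
)

masking = tf.keras.layers.Masking(mask_value=-1)
mask = masking.compute_mask(batch)
print(mask.numpy())
# [[ True  True False False False]
#  [ True  True  True False False]]
# A timestep is masked only when every feature equals mask_value, so the
# padded positions are flagged while real bases are kept.

Downstream layers that are mask-aware (recurrent layers, for example) skip the masked timesteps; the sentinel must match the padding value used when batching, which is what ties this layer to the padded_batch change below.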


@@ -115,8 +115,13 @@ def read_dataset(filepath, hyperparams) -> TFRecordDataset:
     data_input = TFRecordDataset(filenames=filepath)
     dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
     shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
-    batched_dataset = shuffled_dataset.batch(batch_size=hyperparams.batch_size).repeat(
-        count=hyperparams.epochs
+    batched_dataset = shuffled_dataset.padded_batch(
+        batch_size=hyperparams.batch_size,
+        padded_shapes=(
+            [hyperparams.max_length, len(BASES)],
+            [hyperparams.max_length, len(BASES)],
+        ),
+        padding_values=-1.0,
     )
     return batched_dataset
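For reference, a self-contained sketch of what the padded_batch call does; the toy generator stands in for the real TFRecord pipeline and process_input, and the (sequence, label) structure and shapes are assumptions based on the padded_shapes above.

import tensorflow as tf

BASES = "ACGT"
MAX_LENGTH = 80  # mirrors hyperparams.max_length

# Variable-length (sequence, label) pairs standing in for the mapped dataset.
def toy_examples():
    for length in (3, 5):
        seq = tf.one_hot(tf.range(length) % len(BASES), depth=len(BASES))
        yield seq, seq

dataset = tf.data.Dataset.from_generator(
    toy_examples,
    output_signature=(
        tf.TensorSpec(shape=(None, len(BASES)), dtype=tf.float32),
        tf.TensorSpec(shape=(None, len(BASES)), dtype=tf.float32),
    ),
)

batched = dataset.padded_batch(
    batch_size=2,
    padded_shapes=([MAX_LENGTH, len(BASES)], [MAX_LENGTH, len(BASES)]),
    padding_values=-1.0,
)

sequences, labels = next(iter(batched))
print(sequences.shape)  # (2, 80, 4): both examples padded to max_length with -1.0

The single scalar padding_values=-1.0 is applied to every component of the dataset, so both sequences and labels receive the same sentinel that the Masking layer in the model looks for.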