Pad and mask the sequences in each batch
parent 70363a82a0
commit 1a1262b0b1

@@ -11,6 +11,7 @@ class Hyperparameters:
         learning_rate=0.004,
         l2_rate=0.001,
         log_directory="logs",
+        max_length=80,
     ):
         self.data_file = data_file
         self.label_file = label_file
@@ -22,3 +23,4 @@ class Hyperparameters:
         self.learning_rate = learning_rate
         self.l2_rate = l2_rate
         self.log_directory = log_directory
+        self.max_length = max_length
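
For context, a minimal usage sketch of the new hyperparameter (the file paths and the surrounding call sites are assumptions, not part of this commit):

# Hypothetical call sites; only max_length=80 comes from the diff above.
hyperparams = Hyperparameters(
    data_file="train.tfrecord",    # assumed path
    label_file="labels.tfrecord",  # assumed path
    max_length=80,                 # new: sequences are padded to this length per batch
)
dataset = read_dataset(hyperparams.data_file, hyperparams)
model = build_model(hyperparams)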
@@ -17,7 +17,8 @@ def build_model(hyperparams) -> Model:
     """
     model = Sequential(
         [
-            Input(shape=(None, len(BASES))),
+            Input(shape=(None, hyperparams.max_length, len(BASES))),
+            Masking(mask_value=-1),
             Conv1D(
                 filters=16,
                 kernel_size=5,
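
As a quick standalone check of the masking behaviour this change relies on (toy input, not repository code; 4 one-hot features standing in for len(BASES) is an assumption): Keras' Masking layer marks a timestep as padding only when every feature at that step equals mask_value, which is why the batches are filled with -1 across the whole one-hot vector.

import numpy as np
from tensorflow.keras.layers import Masking

# One sequence of 3 timesteps with 4 one-hot features; the last two steps are padding.
x = np.array([[[0.0, 1.0, 0.0, 0.0],
               [-1.0, -1.0, -1.0, -1.0],
               [-1.0, -1.0, -1.0, -1.0]]])
mask = Masking(mask_value=-1.0).compute_mask(x)
print(mask.numpy())  # [[ True False False]]: mask-aware layers skip the padded steps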
@@ -115,8 +115,13 @@ def read_dataset(filepath, hyperparams) -> TFRecordDataset:
     data_input = TFRecordDataset(filenames=filepath)
     dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
     shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
-    batched_dataset = shuffled_dataset.batch(batch_size=hyperparams.batch_size).repeat(
-        count=hyperparams.epochs
+    batched_dataset = shuffled_dataset.padded_batch(
+        batch_size=hyperparams.batch_size,
+        padded_shapes=(
+            [hyperparams.max_length, len(BASES)],
+            [hyperparams.max_length, len(BASES)],
+        ),
+        padding_values=-1.0,
     )
     return batched_dataset
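
To see what the padded_batch call produces, a self-contained toy sketch (the generator data and a 4-letter base alphabet are assumptions; the real pipeline reads TFRecords):

import tensorflow as tf

# Two (sequence, label) pairs of different lengths, one-hot over 4 bases.
def gen():
    yield tf.one_hot([0, 1, 2], depth=4), tf.one_hot([0, 1, 2], depth=4)
    yield tf.one_hot([3, 0], depth=4), tf.one_hot([3, 0], depth=4)

spec = tf.TensorSpec(shape=(None, 4), dtype=tf.float32)
dataset = tf.data.Dataset.from_generator(gen, output_signature=(spec, spec))

batched = dataset.padded_batch(
    batch_size=2,
    padded_shapes=([5, 4], [5, 4]),  # pad every example to a fixed length of 5
    padding_values=-1.0,             # same fill value the Masking layer treats as padding
)
for sequences, labels in batched:
    print(sequences.shape)  # (2, 5, 4): shorter examples are right-padded with -1.0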