diff --git a/src/hyperparameters.py b/src/hyperparameters.py
index e7b5543..6e9704a 100644
--- a/src/hyperparameters.py
+++ b/src/hyperparameters.py
@@ -11,6 +11,7 @@ class Hyperparameters:
         learning_rate=0.004,
         l2_rate=0.001,
         log_directory="logs",
+        max_length=80,
     ):
         self.data_file = data_file
         self.label_file = label_file
@@ -22,3 +23,4 @@ class Hyperparameters:
         self.learning_rate = learning_rate
         self.l2_rate = l2_rate
         self.log_directory = log_directory
+        self.max_length = max_length
diff --git a/src/model.py b/src/model.py
index 97ba5ef..907fefa 100644
--- a/src/model.py
+++ b/src/model.py
@@ -17,7 +17,8 @@ def build_model(hyperparams) -> Model:
     """
     model = Sequential(
         [
-            Input(shape=(None, len(BASES))),
+            Input(shape=(hyperparams.max_length, len(BASES))),
+            Masking(mask_value=-1),
             Conv1D(
                 filters=16,
                 kernel_size=5,
diff --git a/src/preprocessing.py b/src/preprocessing.py
index e86eeac..5a1e96d 100644
--- a/src/preprocessing.py
+++ b/src/preprocessing.py
@@ -115,8 +115,13 @@ def read_dataset(filepath, hyperparams) -> TFRecordDataset:
     data_input = TFRecordDataset(filenames=filepath)
     dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
     shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
-    batched_dataset = shuffled_dataset.batch(batch_size=hyperparams.batch_size).repeat(
-        count=hyperparams.epochs
+    batched_dataset = shuffled_dataset.padded_batch(
+        batch_size=hyperparams.batch_size,
+        padded_shapes=(
+            [hyperparams.max_length, len(BASES)],
+            [hyperparams.max_length, len(BASES)],
+        ),
+        padding_values=-1.0,
     )
     return batched_dataset
 