diff --git a/src/model.py b/src/model.py index 39f7ab4..413b683 100644 --- a/src/model.py +++ b/src/model.py @@ -20,7 +20,7 @@ def build_model(hyperparams) -> Model: """ model = Sequential( [ - Input(shape=(hyperparams.batch_size, hyperparams.max_length, len(BASES))), + Input(shape=(hyperparams.max_length, len(BASES))), Masking(mask_value=-1), Dense( units=256, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)