Move hyperparameters to a class

coolneng 2021-07-05 03:24:54 +02:00
parent e07d0dcdbf
commit a3780c9761
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
4 changed files with 64 additions and 43 deletions

src/constants.py (deleted file, 9 deletions)

@@ -1,9 +0,0 @@
-BASES = "ACGT-"
-TRAIN_DATASET = "data/train_data.tfrecords"
-TEST_DATASET = "data/test_data.tfrecords"
-EVAL_DATASET = "data/eval_data.tfrecords"
-EPOCHS = 1000
-BATCH_SIZE = 1
-LEARNING_RATE = 0.004
-L2 = 0.001
-LOG_DIR = "logs"

src/hyperparameters.py (new file, 24 additions)

@@ -0,0 +1,24 @@
+class Hyperparameters:
+    def __init__(
+        self,
+        data_file,
+        label_file,
+        train_dataset="data/train_data.tfrecords",
+        test_dataset="data/test_data.tfrecords",
+        eval_dataset="data/eval_data.tfrecords",
+        epochs=1000,
+        batch_size=256,
+        learning_rate=0.004,
+        l2_rate=0.001,
+        log_directory="logs",
+    ):
+        self.data_file = data_file
+        self.label_file = label_file
+        self.train_dataset = train_dataset
+        self.eval_dataset = eval_dataset
+        self.test_dataset = test_dataset
+        self.epochs = epochs
+        self.batch_size = batch_size
+        self.learning_rate = learning_rate
+        self.l2_rate = l2_rate
+        self.log_directory = log_directory
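
As a quick orientation, a minimal usage sketch of the new class (the FASTQ paths are placeholders; everything else mirrors the defaults above):

    from hyperparameters import Hyperparameters

    # Only data_file and label_file are required; the rest fall back to defaults
    hp = Hyperparameters(data_file="data/reads.fastq", label_file="data/labels.fastq")
    assert hp.batch_size == 256
    assert hp.train_dataset == "data/train_data.tfrecords"

    # Any default can be overridden per experiment
    quick = Hyperparameters(
        data_file="data/reads.fastq",
        label_file="data/labels.fastq",
        epochs=10,
        batch_size=32,
    )

A dataclasses.dataclass would generate this __init__ (and a repr) for free; the hand-written version is equivalent, just more verbose.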

src/model.py

@@ -2,17 +2,16 @@ from random import seed
 from tensorflow.keras import Model, Sequential
 from tensorflow.keras.layers import *
-from tensorflow.keras.callbacks import TensorBoard
 from tensorflow.keras.losses import categorical_crossentropy
 from tensorflow.keras.optimizers import Adam
 from tensorflow.keras.regularizers import l2
 from tensorflow.random import set_seed
 
-from constants import *
-from preprocessing import dataset_creation
+from hyperparameters import Hyperparameters
+from preprocessing import BASES, dataset_creation
 
 
-def build_model() -> Model:
+def build_model(hyperparams) -> Model:
     """
     Build the CNN model
     """
@@ -20,23 +19,33 @@ def build_model() -> Model:
         [
             Input(shape=(None, len(BASES))),
             Conv1D(
-                filters=16, kernel_size=5, activation="relu", kernel_regularizer=l2(L2)
+                filters=16,
+                kernel_size=5,
+                activation="relu",
+                kernel_regularizer=l2(hyperparams.l2_rate),
             ),
             MaxPool1D(pool_size=3, strides=1),
             Conv1D(
-                filters=16, kernel_size=3, activation="relu", kernel_regularizer=l2(L2)
+                filters=16,
+                kernel_size=3,
+                activation="relu",
+                kernel_regularizer=l2(hyperparams.l2_rate),
            ),
             MaxPool1D(pool_size=3, strides=1),
             GlobalAveragePooling1D(),
-            Dense(units=16, activation="relu", kernel_regularizer=l2(L2)),
+            Dense(
+                units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
+            ),
             Dropout(rate=0.3),
-            Dense(units=16, activation="relu", kernel_regularizer=l2(L2)),
+            Dense(
+                units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
+            ),
             Dropout(rate=0.3),
             Dense(units=len(BASES), activation="softmax"),
         ]
     )
     model.compile(
-        optimizer=Adam(LEARNING_RATE),
+        optimizer=Adam(hyperparams.learning_rate),
         loss=categorical_crossentropy,
         metrics=["accuracy"],
     )
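
A note on the regularizer calls above: l2() is a factory, so each layer gets its own regularizer instance computing the same penalty. A tiny sketch of the semantics:

    from tensorflow.keras.regularizers import l2

    reg = l2(0.001)
    # During training, Keras adds reg(weights) = 0.001 * sum(weights ** 2)
    # to the loss for every kernel the regularizer is attached to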
@@ -59,17 +68,12 @@ def run(data_file, label_file, seed_value=42) -> None:
     """
     seed(seed_value)
     set_seed(seed_value)
-    train_data, eval_data, test_data = dataset_creation(data_file, label_file)
-    tensorboard = TensorBoard(log_dir=LOG_DIR, histogram_freq=1, profile_batch=0)
-    model = build_model()
+    hyperparams = Hyperparameters(data_file=data_file, label_file=label_file)
+    train_data, eval_data, test_data = dataset_creation(hyperparams)
+    model = build_model(hyperparams)
     print("Training the model")
-    model.fit(
-        train_data,
-        epochs=EPOCHS,
-        validation_data=eval_data,
-        callbacks=[tensorboard],
-    )
-    print("Training complete. Obtaining final metrics...")
+    model.fit(train_data, epochs=hyperparams.epochs, validation_data=eval_data)
+    print("Training complete. Obtaining the model's metrics...")
     show_metrics(model, eval_data, test_data)
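
An end-to-end sketch of driving this training entry point, assuming the file is src/model.py as its imports suggest (the FASTQ paths are placeholders):

    from model import run

    # Seeds Python's and TensorFlow's RNGs, builds a Hyperparameters object
    # internally, trains the CNN, and prints the evaluation metrics
    run(data_file="data/reads.fastq", label_file="data/labels.fastq", seed_value=42)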

src/preprocessing.py

@@ -9,7 +9,7 @@ from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
 from tensorflow.sparse import to_dense
 from tensorflow.train import Example, Feature, Features, Int64List
 
-from constants import *
+BASES = "ACGT-"
 
 
 def align_sequences(sequence, label) -> Tuple[str, str]:
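
BASES now lives alongside the preprocessing code and is imported by the model, whose Input(shape=(None, len(BASES))) expects one channel per symbol (four bases plus the gap "-"). The repo's encode_sequence isn't shown in full; as an illustration only, one encoding consistent with that input shape:

    BASES = "ACGT-"

    def one_hot(sequence: str) -> list:
        # "A" -> [1, 0, 0, 0, 0], "-" -> [0, 0, 0, 0, 1], etc.
        index = {base: i for i, base in enumerate(BASES)}
        return [
            [1 if column == index[base] else 0 for column in range(len(BASES))]
            for base in sequence
        ]

    one_hot("AC-")  # 3 timesteps x 5 channels, matching the model's input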
@ -43,26 +43,26 @@ def encode_sequence(sequence) -> List[int]:
return encoded_sequence return encoded_sequence
def read_fastq(data_file, label_file) -> List[bytes]: def read_fastq(hyperparams) -> List[bytes]:
""" """
Parses a data and a label FASTQ files and generates a List of serialized Examples Parses a data and a label FASTQ files and generates a List of serialized Examples
""" """
examples = [] examples = []
with open(data_file) as data, open(label_file) as labels: with open(hyperparams.data_file) as data, open(hyperparams.label_file) as labels:
for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")): for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
example = generate_example(sequence=str(element.seq), label=str(label.seq)) example = generate_example(sequence=str(element.seq), label=str(label.seq))
examples.append(example) examples.append(example)
return examples return examples
def create_dataset(data_file, label_file, dataset_split=[0.8, 0.1, 0.1]) -> None: def create_dataset(hyperparams, dataset_split=[0.8, 0.1, 0.1]) -> None:
""" """
Create a training, evaluation and test dataset with a 80/10/10 split respectively Create a training, evaluation and test dataset with a 80/10/10 split respectively
""" """
data = read_fastq(data_file, label_file) data = read_fastq(hyperparams)
with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter( with TFRecordWriter(hyperparams.train_dataset) as training, TFRecordWriter(
TEST_DATASET hyperparams.test_dataset
) as test, TFRecordWriter(EVAL_DATASET) as evaluation: ) as test, TFRecordWriter(hyperparams.eval_dataset) as evaluation:
for element in data: for element in data:
if random() < dataset_split[0]: if random() < dataset_split[0]:
training.write(element) training.write(element)
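
The hunk is cut off after the first branch; for intuition, a self-contained sketch of this style of per-element probabilistic split (illustrative, not the repo's exact branches):

    from random import random

    def split_bucket(dataset_split=(0.8, 0.1, 0.1)) -> str:
        # Each record is routed independently, so the 80/10/10 proportions
        # hold in expectation rather than exactly
        draw = random()
        if draw < dataset_split[0]:
            return "train"
        if draw < dataset_split[0] + dataset_split[1]:
            return "test"
        return "eval"

As an aside, a tuple default sidesteps Python's shared mutable default pitfall that dataset_split=[0.8, 0.1, 0.1] carries.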
@@ -97,25 +97,27 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
     return features["sequence"], features["label"]
 
 
-def read_dataset(filepath) -> TFRecordDataset:
+def read_dataset(filepath, hyperparams) -> TFRecordDataset:
     """
     Read TFRecords files and generate a dataset
     """
     data_input = TFRecordDataset(filenames=filepath)
     dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
     shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
-    batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
+    batched_dataset = shuffled_dataset.batch(batch_size=hyperparams.batch_size).repeat(
+        count=hyperparams.epochs
+    )
     return batched_dataset
 
 
 def dataset_creation(
-    data_file, label_file
+    hyperparams,
 ) -> Tuple[TFRecordDataset, TFRecordDataset, TFRecordDataset]:
     """
     Generate the TFRecord files and split them into training, validation and test data
     """
-    create_dataset(data_file, label_file)
-    train_data = read_dataset(TRAIN_DATASET)
-    eval_data = read_dataset(EVAL_DATASET)
-    test_data = read_dataset(TEST_DATASET)
+    create_dataset(hyperparams)
+    train_data = read_dataset(hyperparams.train_dataset, hyperparams)
+    eval_data = read_dataset(hyperparams.eval_dataset, hyperparams)
+    test_data = read_dataset(hyperparams.test_dataset, hyperparams)
     return train_data, eval_data, test_data
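
Putting the refactor together, a sketch of the preprocessing entry point after this change (the FASTQ paths are placeholders; the TFRecord paths come from the class defaults):

    from hyperparameters import Hyperparameters
    from preprocessing import dataset_creation

    hp = Hyperparameters(data_file="data/reads.fastq", label_file="data/labels.fastq")
    # Writes data/{train,test,eval}_data.tfrecords, then reads them back as
    # shuffled, batched, repeated TFRecordDatasets
    train_data, eval_data, test_data = dataset_creation(hp)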