Implement model inference of sequences

This commit is contained in:
coolneng 2021-07-06 02:59:37 +02:00
parent 1333a9256b
commit 92c6b54966
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
3 changed files with 47 additions and 18 deletions

View File

@ -1,26 +1,33 @@
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from time import time
from model import run from model import train_model, infer_sequence
def parse_arguments() -> Namespace: def parse_arguments() -> Namespace:
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument( subparsers = parser.add_subparsers(dest="task")
parser_train = subparsers.add_parser("train")
parser_infer = subparsers.add_parser("infer")
parser_train.add_argument(
"data_file", help="FASTQ file containing the sequences with errors" "data_file", help="FASTQ file containing the sequences with errors"
) )
parser.add_argument( parser_train.add_argument(
"label_file", help="FASTQ file containing the sequences without errors" "label_file", help="FASTQ file containing the sequences without errors"
) )
parser_infer.add_argument("sequence", help="DNA sequence with errors")
return parser.parse_args() return parser.parse_args()
def execute_task(args: Namespace) -> None:
    """
    Dispatch to the task selected by the command line subcommand

    "train" trains a model from args.data_file and args.label_file,
    "infer" corrects the single sequence held in args.sequence.

    Raises ValueError when args.task is neither "train" nor "infer".
    """
    if args.task == "train":
        train_model(data_file=args.data_file, label_file=args.label_file)
    elif args.task == "infer":
        infer_sequence(sequence=args.sequence)
    else:
        # Subparsers are optional by default, so task is None when the user
        # omits the subcommand; fail with a clear message instead of an
        # AttributeError on a missing args.sequence
        raise ValueError(
            f"Unknown task {args.task!r}: expected 'train' or 'infer'"
        )
def main() -> None: def main() -> None:
args = parse_arguments() args = parse_arguments()
start_time = time() execute_task(args)
run(data_file=args.data_file, label_file=args.label_file)
end_time = time()
print(f"Elapsed time: {end_time - start_time}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,14 +1,17 @@
from random import seed from random import seed
from numpy import argmax
from tensorflow import one_hot
from tensorflow.keras import Model, Sequential from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import * from tensorflow.keras.layers import Dense, Dropout, Input, Masking
from tensorflow.keras.models import load_model
from tensorflow.keras.losses import categorical_crossentropy from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.optimizers import Adam from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2 from tensorflow.keras.regularizers import l2
from tensorflow.random import set_seed from tensorflow.random import set_seed
from hyperparameters import Hyperparameters from hyperparameters import Hyperparameters
from preprocessing import BASES, dataset_creation from preprocessing import BASES, dataset_creation, decode_sequence, encode_sequence
def build_model(hyperparams) -> Model: def build_model(hyperparams) -> Model:
@ -20,19 +23,15 @@ def build_model(hyperparams) -> Model:
Input(shape=(hyperparams.batch_size, hyperparams.max_length, len(BASES))), Input(shape=(hyperparams.batch_size, hyperparams.max_length, len(BASES))),
Masking(mask_value=-1), Masking(mask_value=-1),
Dense( Dense(
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate) units=256, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
), ),
Dropout(rate=0.3), Dropout(rate=0.3),
Dense( Dense(
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate) units=128, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
), ),
Dropout(rate=0.3), Dropout(rate=0.3),
Dense( Dense(
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate) units=64, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
),
Dropout(rate=0.3),
Dense(
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
), ),
Dropout(rate=0.3), Dropout(rate=0.3),
Dense(units=len(BASES), activation="softmax"), Dense(units=len(BASES), activation="softmax"),
@ -56,7 +55,7 @@ def show_metrics(model, eval_dataset, test_dataset) -> None:
print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}") print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}")
def run(data_file, label_file, seed_value=42) -> None: def train_model(data_file, label_file, seed_value=42) -> None:
""" """
Create a dataset, a model and runs training and evaluation on it Create a dataset, a model and runs training and evaluation on it
""" """
@ -69,3 +68,17 @@ def run(data_file, label_file, seed_value=42) -> None:
model.fit(train_data, epochs=hyperparams.epochs, validation_data=eval_data) model.fit(train_data, epochs=hyperparams.epochs, validation_data=eval_data)
print("Training complete. Obtaining the model's metrics...") print("Training complete. Obtaining the model's metrics...")
show_metrics(model, eval_data, test_data) show_metrics(model, eval_data, test_data)
model.save("trained_model")
def infer_sequence(sequence) -> None:
    """
    Load the saved model, run it on a single sequence and print the result

    The sequence is index encoded, one-hot encoded over len(BASES) classes,
    passed to the model, and the most likely class per row is decoded back
    into a string.
    """
    trained = load_model("trained_model")
    # NOTE(review): assumes the saved model accepts an input of shape
    # (len(sequence), len(BASES)) - confirm against build_model
    model_input = one_hot(encode_sequence(sequence), depth=len(BASES))
    probabilities = trained.predict(model_input)
    # Keep the index of the highest scoring class along axis 1
    predicted_indices = argmax(probabilities, axis=1)
    corrected = decode_sequence(predicted_indices)
    print(f"Error-corrected sequence: {corrected}")

View File

@ -30,6 +30,15 @@ def encode_sequence(sequence) -> List[int]:
return encoded_sequence return encoded_sequence
def decode_sequence(sequence) -> str:
    """
    Turn a sequence of indices into a string by looking each one up in BASES
    """
    return "".join(BASES[index] for index in sequence)
def prepare_sequences(sequence, label): def prepare_sequences(sequence, label):
""" """
Align and encode the sequences to obtain a fixed length output in order to perform batching Align and encode the sequences to obtain a fixed length output in order to perform batching