Implement model inference of sequences

This commit is contained in:
coolneng 2021-07-06 02:59:37 +02:00
parent 1333a9256b
commit 92c6b54966
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
3 changed files with 47 additions and 18 deletions

View File

@ -1,26 +1,33 @@
from argparse import ArgumentParser, Namespace
from time import time
from model import run
from model import train_model, infer_sequence
def parse_arguments() -> Namespace:
parser = ArgumentParser()
parser.add_argument(
subparsers = parser.add_subparsers(dest="task")
parser_train = subparsers.add_parser("train")
parser_infer = subparsers.add_parser("infer")
parser_train.add_argument(
"data_file", help="FASTQ file containing the sequences with errors"
)
parser.add_argument(
parser_train.add_argument(
"label_file", help="FASTQ file containing the sequences without errors"
)
parser_infer.add_argument("sequence", help="DNA sequence with errors")
return parser.parse_args()
def execute_task(args):
if args.task == "train":
train_model(data_file=args.data_file, label_file=args.label_file)
else:
infer_sequence(sequence=args.sequence)
def main() -> None:
args = parse_arguments()
start_time = time()
run(data_file=args.data_file, label_file=args.label_file)
end_time = time()
print(f"Elapsed time: {end_time - start_time}")
execute_task(args)
if __name__ == "__main__":

View File

@ -1,14 +1,17 @@
from random import seed
from numpy import argmax
from tensorflow import one_hot
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import *
from tensorflow.keras.layers import Dense, Dropout, Input, Masking
from tensorflow.keras.models import load_model
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.random import set_seed
from hyperparameters import Hyperparameters
from preprocessing import BASES, dataset_creation
from preprocessing import BASES, dataset_creation, decode_sequence, encode_sequence
def build_model(hyperparams) -> Model:
@ -20,19 +23,15 @@ def build_model(hyperparams) -> Model:
Input(shape=(hyperparams.batch_size, hyperparams.max_length, len(BASES))),
Masking(mask_value=-1),
Dense(
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
units=256, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
),
Dropout(rate=0.3),
Dense(
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
units=128, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
),
Dropout(rate=0.3),
Dense(
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
),
Dropout(rate=0.3),
Dense(
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
units=64, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
),
Dropout(rate=0.3),
Dense(units=len(BASES), activation="softmax"),
@ -56,7 +55,7 @@ def show_metrics(model, eval_dataset, test_dataset) -> None:
print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}")
def run(data_file, label_file, seed_value=42) -> None:
def train_model(data_file, label_file, seed_value=42) -> None:
"""
Create a dataset, a model and runs training and evaluation on it
"""
@ -69,3 +68,17 @@ def run(data_file, label_file, seed_value=42) -> None:
model.fit(train_data, epochs=hyperparams.epochs, validation_data=eval_data)
print("Training complete. Obtaining the model's metrics...")
show_metrics(model, eval_data, test_data)
model.save("trained_model")
def infer_sequence(sequence) -> None:
"""
Predict the correct sequence, using the trained model
"""
model = load_model("trained_model")
encoded_sequence = encode_sequence(sequence)
one_hot_encoded_sequence = one_hot(encoded_sequence, depth=len(BASES))
prediction = model.predict(one_hot_encoded_sequence)
encoded_prediction = argmax(prediction, axis=1)
final_prediction = decode_sequence(encoded_prediction)
print(f"Error-corrected sequence: {final_prediction}")

View File

@ -30,6 +30,15 @@ def encode_sequence(sequence) -> List[int]:
return encoded_sequence
def decode_sequence(sequence) -> str:
"""
Decode an index encoded sequence back to the human readable format
"""
decoded_list = [BASES[element] for element in sequence]
sequence = "".join(decoded_list)
return sequence
def prepare_sequences(sequence, label):
"""
Align and encode the sequences to obtain a fixed length output in order to perform batching