Implement model inference of sequences
parent 1333a9256b
commit 92c6b54966
src/main.py (23 changed lines)
@@ -1,26 +1,33 @@
 from argparse import ArgumentParser, Namespace
 from time import time
 
-from model import run
+from model import train_model, infer_sequence
 
 
 def parse_arguments() -> Namespace:
     parser = ArgumentParser()
-    parser.add_argument(
+    subparsers = parser.add_subparsers(dest="task")
+    parser_train = subparsers.add_parser("train")
+    parser_infer = subparsers.add_parser("infer")
+    parser_train.add_argument(
         "data_file", help="FASTQ file containing the sequences with errors"
     )
-    parser.add_argument(
+    parser_train.add_argument(
         "label_file", help="FASTQ file containing the sequences without errors"
     )
+    parser_infer.add_argument("sequence", help="DNA sequence with errors")
     return parser.parse_args()
 
 
+def execute_task(args):
+    if args.task == "train":
+        train_model(data_file=args.data_file, label_file=args.label_file)
+    else:
+        infer_sequence(sequence=args.sequence)
+
+
 def main() -> None:
     args = parse_arguments()
-    start_time = time()
-    run(data_file=args.data_file, label_file=args.label_file)
-    end_time = time()
-    print(f"Elapsed time: {end_time - start_time}")
+    execute_task(args)
 
 
 if __name__ == "__main__":
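For context, the CLI now routes through subcommands instead of a single argument set. A minimal sketch of how the new parser behaves (the sequence string below is a made-up example, not from the commit):

from argparse import ArgumentParser

parser = ArgumentParser()
subparsers = parser.add_subparsers(dest="task")
parser_train = subparsers.add_parser("train")
parser_train.add_argument("data_file")
parser_train.add_argument("label_file")
parser_infer = subparsers.add_parser("infer")
parser_infer.add_argument("sequence")

# Same shape as: python src/main.py infer ACGTACGT
args = parser.parse_args(["infer", "ACGTACGT"])
print(args.task, args.sequence)  # -> infer ACGTACGT

Two things worth noting: `from time import time` appears to become unused now that the timing code is gone, and `execute_task` falls into the inference branch whenever `args.task` is not "train", so invoking the script with no subcommand at all would hit `args.sequence` on a namespace that lacks it.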
src/model.py (33 changed lines)
@@ -1,14 +1,17 @@
 from random import seed
 
+from numpy import argmax
+from tensorflow import one_hot
 from tensorflow.keras import Model, Sequential
-from tensorflow.keras.layers import *
+from tensorflow.keras.layers import Dense, Dropout, Input, Masking
+from tensorflow.keras.models import load_model
 from tensorflow.keras.losses import categorical_crossentropy
 from tensorflow.keras.optimizers import Adam
 from tensorflow.keras.regularizers import l2
 from tensorflow.random import set_seed
 
 from hyperparameters import Hyperparameters
-from preprocessing import BASES, dataset_creation
+from preprocessing import BASES, dataset_creation, decode_sequence, encode_sequence
 
 
 def build_model(hyperparams) -> Model:
@@ -20,19 +23,15 @@ def build_model(hyperparams) -> Model:
         Input(shape=(hyperparams.batch_size, hyperparams.max_length, len(BASES))),
         Masking(mask_value=-1),
         Dense(
-            units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
+            units=256, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
         ),
         Dropout(rate=0.3),
         Dense(
-            units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
+            units=128, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
         ),
         Dropout(rate=0.3),
-        Dense(
-            units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
-        ),
-        Dropout(rate=0.3),
         Dense(
-            units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
+            units=64, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
         ),
         Dropout(rate=0.3),
         Dense(units=len(BASES), activation="softmax"),
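The hidden layers change from a uniform 16 units to a narrowing 256 → 128 → 64 stack, dropping one Dense/Dropout pair along the way. A standalone sketch of the resulting architecture, with hyperparameter values filled in as assumptions (the real values come from the repo's Hyperparameters class):

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout, Input, Masking
from tensorflow.keras.regularizers import l2

BASES = "ACGT"                                    # assumed; defined in preprocessing.py
batch_size, max_length, l2_rate = 32, 100, 0.001  # assumed values

model = Sequential([
    Input(shape=(batch_size, max_length, len(BASES))),
    Masking(mask_value=-1),
    Dense(units=256, activation="relu", kernel_regularizer=l2(l2_rate)),
    Dropout(rate=0.3),
    Dense(units=128, activation="relu", kernel_regularizer=l2(l2_rate)),
    Dropout(rate=0.3),
    Dense(units=64, activation="relu", kernel_regularizer=l2(l2_rate)),
    Dropout(rate=0.3),
    Dense(units=len(BASES), activation="softmax"),
])
model.summary()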
@@ -56,7 +55,7 @@ def show_metrics(model, eval_dataset, test_dataset) -> None:
     print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}")
 
 
-def run(data_file, label_file, seed_value=42) -> None:
+def train_model(data_file, label_file, seed_value=42) -> None:
     """
     Create a dataset and a model, then run training and evaluation
     """
@@ -69,3 +68,17 @@ def run(data_file, label_file, seed_value=42) -> None:
     model.fit(train_data, epochs=hyperparams.epochs, validation_data=eval_data)
     print("Training complete. Obtaining the model's metrics...")
     show_metrics(model, eval_data, test_data)
+    model.save("trained_model")
+
+
+def infer_sequence(sequence) -> None:
+    """
+    Predict the correct sequence, using the trained model
+    """
+    model = load_model("trained_model")
+    encoded_sequence = encode_sequence(sequence)
+    one_hot_encoded_sequence = one_hot(encoded_sequence, depth=len(BASES))
+    prediction = model.predict(one_hot_encoded_sequence)
+    encoded_prediction = argmax(prediction, axis=1)
+    final_prediction = decode_sequence(encoded_prediction)
+    print(f"Error-corrected sequence: {final_prediction}")
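The new inference path one-hot encodes the input sequence, asks the saved model for per-position probabilities over the four bases, and converts the argmax indices back into letters. A sketch of just the decoding step, assuming BASES is index-ordered as "ACGT" (the probabilities below are fabricated):

from numpy import argmax, array

BASES = "ACGT"  # assumed ordering; the real constant lives in preprocessing.py

# Pretend softmax output for a 3-base sequence: one row per position.
prediction = array([
    [0.90, 0.05, 0.03, 0.02],  # index 0 -> "A"
    [0.10, 0.15, 0.70, 0.05],  # index 2 -> "G"
    [0.05, 0.80, 0.10, 0.05],  # index 1 -> "C"
])
encoded_prediction = argmax(prediction, axis=1)       # array([0, 2, 1])
print("".join(BASES[i] for i in encoded_prediction))  # AGC

One caveat worth checking: Keras' `model.predict` treats the first axis of its input as the batch axis, so the two-dimensional (length, 4) tensor that `one_hot` produces here may need a leading batch dimension before prediction.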
src/preprocessing.py

@@ -30,6 +30,15 @@ def encode_sequence(sequence) -> List[int]:
     return encoded_sequence
 
 
+def decode_sequence(sequence) -> str:
+    """
+    Decode an index encoded sequence back to the human readable format
+    """
+    decoded_list = [BASES[element] for element in sequence]
+    sequence = "".join(decoded_list)
+    return sequence
+
+
 def prepare_sequences(sequence, label):
     """
     Align and encode the sequences to obtain a fixed length output in order to perform batching
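The added `decode_sequence` is the inverse of the existing `encode_sequence`: indices in, bases out. A round-trip sketch, under the assumption that encode_sequence maps each base to its index in BASES and that BASES is ordered "ACGT":

BASES = "ACGT"  # assumption for this sketch

def encode_sequence(sequence):  # stand-in for the real helper in preprocessing.py
    return [BASES.index(base) for base in sequence]

def decode_sequence(sequence) -> str:
    decoded_list = [BASES[element] for element in sequence]
    return "".join(decoded_list)

assert decode_sequence(encode_sequence("GATTACA")) == "GATTACA"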