Implement model inference of sequences
This commit is contained in:
parent
1333a9256b
commit
92c6b54966
23
src/main.py
23
src/main.py
|
@ -1,26 +1,33 @@
|
||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from time import time
|
|
||||||
|
|
||||||
from model import run
|
from model import train_model, infer_sequence
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments() -> Namespace:
|
def parse_arguments() -> Namespace:
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument(
|
subparsers = parser.add_subparsers(dest="task")
|
||||||
|
parser_train = subparsers.add_parser("train")
|
||||||
|
parser_infer = subparsers.add_parser("infer")
|
||||||
|
parser_train.add_argument(
|
||||||
"data_file", help="FASTQ file containing the sequences with errors"
|
"data_file", help="FASTQ file containing the sequences with errors"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser_train.add_argument(
|
||||||
"label_file", help="FASTQ file containing the sequences without errors"
|
"label_file", help="FASTQ file containing the sequences without errors"
|
||||||
)
|
)
|
||||||
|
parser_infer.add_argument("sequence", help="DNA sequence with errors")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def execute_task(args):
    """
    Dispatch to the routine selected by the command-line sub-command

    `args.task` holds the sub-command name stored by the argument parser
    ("train" or "infer"); the other attributes read here are the arguments
    declared for that sub-command.

    Raises ValueError when `args.task` is neither "train" nor "infer".
    """
    if args.task == "train":
        train_model(data_file=args.data_file, label_file=args.label_file)
    elif args.task == "infer":
        infer_sequence(sequence=args.sequence)
    else:
        # argparse sub-commands are optional unless declared required, so
        # `task` is None when the program is started without one. Fail with a
        # clear message instead of falling through to inference and crashing
        # on a missing `sequence` attribute.
        raise ValueError(
            f"Unknown task {args.task!r}: expected 'train' or 'infer'"
        )
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
start_time = time()
|
execute_task(args)
|
||||||
run(data_file=args.data_file, label_file=args.label_file)
|
|
||||||
end_time = time()
|
|
||||||
print(f"Elapsed time: {end_time - start_time}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
33
src/model.py
33
src/model.py
|
@ -1,14 +1,17 @@
|
||||||
from random import seed
|
from random import seed
|
||||||
|
|
||||||
|
from numpy import argmax
|
||||||
|
from tensorflow import one_hot
|
||||||
from tensorflow.keras import Model, Sequential
|
from tensorflow.keras import Model, Sequential
|
||||||
from tensorflow.keras.layers import *
|
from tensorflow.keras.layers import Dense, Dropout, Input, Masking
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
from tensorflow.keras.losses import categorical_crossentropy
|
from tensorflow.keras.losses import categorical_crossentropy
|
||||||
from tensorflow.keras.optimizers import Adam
|
from tensorflow.keras.optimizers import Adam
|
||||||
from tensorflow.keras.regularizers import l2
|
from tensorflow.keras.regularizers import l2
|
||||||
from tensorflow.random import set_seed
|
from tensorflow.random import set_seed
|
||||||
|
|
||||||
from hyperparameters import Hyperparameters
|
from hyperparameters import Hyperparameters
|
||||||
from preprocessing import BASES, dataset_creation
|
from preprocessing import BASES, dataset_creation, decode_sequence, encode_sequence
|
||||||
|
|
||||||
|
|
||||||
def build_model(hyperparams) -> Model:
|
def build_model(hyperparams) -> Model:
|
||||||
|
@ -20,19 +23,15 @@ def build_model(hyperparams) -> Model:
|
||||||
Input(shape=(hyperparams.batch_size, hyperparams.max_length, len(BASES))),
|
Input(shape=(hyperparams.batch_size, hyperparams.max_length, len(BASES))),
|
||||||
Masking(mask_value=-1),
|
Masking(mask_value=-1),
|
||||||
Dense(
|
Dense(
|
||||||
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
|
units=256, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
|
||||||
),
|
),
|
||||||
Dropout(rate=0.3),
|
Dropout(rate=0.3),
|
||||||
Dense(
|
Dense(
|
||||||
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
|
units=128, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
|
||||||
),
|
),
|
||||||
Dropout(rate=0.3),
|
Dropout(rate=0.3),
|
||||||
Dense(
|
Dense(
|
||||||
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
|
units=64, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
|
||||||
),
|
|
||||||
Dropout(rate=0.3),
|
|
||||||
Dense(
|
|
||||||
units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
|
|
||||||
),
|
),
|
||||||
Dropout(rate=0.3),
|
Dropout(rate=0.3),
|
||||||
Dense(units=len(BASES), activation="softmax"),
|
Dense(units=len(BASES), activation="softmax"),
|
||||||
|
@ -56,7 +55,7 @@ def show_metrics(model, eval_dataset, test_dataset) -> None:
|
||||||
print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}")
|
print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}")
|
||||||
|
|
||||||
|
|
||||||
def run(data_file, label_file, seed_value=42) -> None:
|
def train_model(data_file, label_file, seed_value=42) -> None:
|
||||||
"""
|
"""
|
||||||
Create a dataset, a model and runs training and evaluation on it
|
Create a dataset, a model and runs training and evaluation on it
|
||||||
"""
|
"""
|
||||||
|
@ -69,3 +68,17 @@ def run(data_file, label_file, seed_value=42) -> None:
|
||||||
model.fit(train_data, epochs=hyperparams.epochs, validation_data=eval_data)
|
model.fit(train_data, epochs=hyperparams.epochs, validation_data=eval_data)
|
||||||
print("Training complete. Obtaining the model's metrics...")
|
print("Training complete. Obtaining the model's metrics...")
|
||||||
show_metrics(model, eval_data, test_data)
|
show_metrics(model, eval_data, test_data)
|
||||||
|
model.save("trained_model")
|
||||||
|
|
||||||
|
|
||||||
|
def infer_sequence(sequence) -> None:
    """
    Run the saved model on one sequence and print the decoded prediction

    The model is loaded from "trained_model". The sequence is index encoded,
    one-hot encoded with a depth of len(BASES) and handed to the model; the
    index of the largest value along axis 1 of the output is decoded back to
    text and printed.
    """
    trained = load_model("trained_model")

    # NOTE(review): build_model declares its Input with a
    # (batch_size, max_length, len(BASES)) shape, while a single un-padded
    # sequence is passed here - confirm the saved model accepts this input.
    indices = encode_sequence(sequence)
    model_input = one_hot(indices, depth=len(BASES))
    raw_output = trained.predict(model_input)

    # Keep, for every row of the output, the index holding the largest value
    best_indices = argmax(raw_output, axis=1)
    print(f"Error-corrected sequence: {decode_sequence(best_indices)}")
|
||||||
|
|
|
@ -30,6 +30,15 @@ def encode_sequence(sequence) -> List[int]:
|
||||||
return encoded_sequence
|
return encoded_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def decode_sequence(sequence) -> str:
    """
    Turn a sequence of indices into BASES back into a readable string

    Each element is looked up in BASES and the results are concatenated in
    their original order.
    """
    return "".join(BASES[index] for index in sequence)
|
||||||
|
|
||||||
|
|
||||||
def prepare_sequences(sequence, label):
|
def prepare_sequences(sequence, label):
|
||||||
"""
|
"""
|
||||||
Align and encode the sequences to obtain a fixed length output in order to perform batching
|
Align and encode the sequences to obtain a fixed length output in order to perform batching
|
||||||
|
|
Loading…
Reference in New Issue