diff --git a/src/main.py b/src/main.py
index 0fa69ef..0350c45 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,26 +1,38 @@
 from argparse import ArgumentParser, Namespace
-from time import time
 
-from model import run
+from model import infer_sequence, train_model
 
 
 def parse_arguments() -> Namespace:
     parser = ArgumentParser()
-    parser.add_argument(
+    subparsers = parser.add_subparsers(dest="task", required=True)
+    parser_train = subparsers.add_parser("train")
+    parser_infer = subparsers.add_parser("infer")
+    parser_train.add_argument(
         "data_file", help="FASTQ file containing the sequences with errors"
     )
-    parser.add_argument(
+    parser_train.add_argument(
         "label_file", help="FASTQ file containing the sequences without errors"
     )
+    parser_infer.add_argument("sequence", help="DNA sequence with errors")
     return parser.parse_args()
 
 
+def execute_task(args: Namespace) -> None:
+    """
+    Dispatch to training or inference according to the selected sub-command
+    """
+    if args.task == "train":
+        train_model(data_file=args.data_file, label_file=args.label_file)
+    elif args.task == "infer":
+        infer_sequence(sequence=args.sequence)
+    else:
+        raise ValueError(f"Unknown task: {args.task!r}")
+
+
 def main() -> None:
     args = parse_arguments()
-    start_time = time()
-    run(data_file=args.data_file, label_file=args.label_file)
-    end_time = time()
-    print(f"Elapsed time: {end_time - start_time}")
+    execute_task(args)
 
 
 if __name__ == "__main__":
diff --git a/src/model.py b/src/model.py
index 30f2ff8..b66dcb2 100644
--- a/src/model.py
+++ b/src/model.py
@@ -1,14 +1,17 @@
 from random import seed
 
+from numpy import argmax
+from tensorflow import one_hot
 from tensorflow.keras import Model, Sequential
-from tensorflow.keras.layers import *
+from tensorflow.keras.layers import Dense, Dropout, Input, Masking
+from tensorflow.keras.models import load_model
 from tensorflow.keras.losses import categorical_crossentropy
 from tensorflow.keras.optimizers import Adam
 from tensorflow.keras.regularizers import l2
 from tensorflow.random import set_seed
 
 from hyperparameters import Hyperparameters
-from preprocessing import BASES, dataset_creation
+from preprocessing import BASES, dataset_creation, decode_sequence, encode_sequence
 
 
 def build_model(hyperparams) -> Model:
@@ -20,19 +23,15 @@ def build_model(hyperparams) -> Model:
             Input(shape=(hyperparams.batch_size, hyperparams.max_length, len(BASES))),
             Masking(mask_value=-1),
             Dense(
-                units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
+                units=256, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
             ),
             Dropout(rate=0.3),
             Dense(
-                units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
+                units=128, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
             ),
             Dropout(rate=0.3),
             Dense(
-                units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
-            ),
-            Dropout(rate=0.3),
-            Dense(
-                units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
+                units=64, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
             ),
             Dropout(rate=0.3),
             Dense(units=len(BASES), activation="softmax"),
@@ -56,7 +55,7 @@ def show_metrics(model, eval_dataset, test_dataset) -> None:
     print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}")
 
 
-def run(data_file, label_file, seed_value=42) -> None:
+def train_model(data_file, label_file, seed_value=42) -> None:
     """
     Create a dataset, a model and runs training and evaluation on it
     """
@@ -69,3 +68,19 @@ def run(data_file, label_file, seed_value=42) -> None:
     model.fit(train_data, epochs=hyperparams.epochs, validation_data=eval_data)
     print("Training complete. Obtaining the model's metrics...")
     show_metrics(model, eval_data, test_data)
+    model.save("trained_model")
+
+
+def infer_sequence(sequence) -> None:
+    """
+    Predict the correct sequence, using the trained model
+    """
+    model = load_model("trained_model")
+    encoded_sequence = encode_sequence(sequence)
+    one_hot_encoded_sequence = one_hot(encoded_sequence, depth=len(BASES))
+    # NOTE(review): build_model declares Input(shape=(batch_size, max_length, len(BASES)));
+    # confirm this un-padded (length, bases) tensor is a shape the saved model accepts.
+    prediction = model.predict(one_hot_encoded_sequence)
+    encoded_prediction = argmax(prediction, axis=-1)
+    final_prediction = decode_sequence(encoded_prediction)
+    print(f"Error-corrected sequence: {final_prediction}")
diff --git a/src/preprocessing.py b/src/preprocessing.py
index 5a1e96d..28efb54 100644
--- a/src/preprocessing.py
+++ b/src/preprocessing.py
@@ -30,6 +30,15 @@ def encode_sequence(sequence) -> List[int]:
     return encoded_sequence
 
 
+def decode_sequence(sequence) -> str:
+    """
+    Decode an index encoded sequence back to the human readable format
+    """
+    decoded_list = [BASES[element] for element in sequence]
+    sequence = "".join(decoded_list)
+    return sequence
+
+
 def prepare_sequences(sequence, label):
     """
     Align and encode the sequences to obtain a fixed length output in order to perform batching