diff --git a/src/main.py b/src/main.py index 0ef6fd7..387cbb5 100644 --- a/src/main.py +++ b/src/main.py @@ -1,3 +1,4 @@ +from asyncio import run from argparse import ArgumentParser, Namespace from model import infer_sequence, train_model @@ -18,17 +19,17 @@ def parse_arguments() -> Namespace: return parser.parse_args() -def execute_task(args): +async def execute_task(args): if args.task == "train": train_model(data_file=args.data_file, label_file=args.label_file) else: - prediction = infer_sequence(sequence=args.sequence) + prediction = await infer_sequence(sequence=args.sequence) print(f"Error-corrected sequence: {prediction}") def main() -> None: args = parse_arguments() - execute_task(args) + run(execute_task(args)) if __name__ == "__main__":