From fba3c5318b7b75cf65db971f4d951f57ee35a10d Mon Sep 17 00:00:00 2001 From: coolneng Date: Tue, 6 Jul 2021 17:53:45 +0200 Subject: [PATCH] Await prediction and print it in the caller --- src/api.py | 1 + src/main.py | 3 ++- src/model.py | 5 ++--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/api.py b/src/api.py index 45a9319..1f069b5 100644 --- a/src/api.py +++ b/src/api.py @@ -1,5 +1,6 @@ from fastapi import FastAPI from pydantic import BaseModel + from model import infer_sequence app = FastAPI() diff --git a/src/main.py b/src/main.py index e0b8990..0ef6fd7 100644 --- a/src/main.py +++ b/src/main.py @@ -22,7 +22,8 @@ def execute_task(args): if args.task == "train": train_model(data_file=args.data_file, label_file=args.label_file) else: - infer_sequence(sequence=args.sequence) + prediction = infer_sequence(sequence=args.sequence) + print(f"Error-corrected sequence: {prediction}") def main() -> None: diff --git a/src/model.py b/src/model.py index 7e941dd..e113fb0 100644 --- a/src/model.py +++ b/src/model.py @@ -34,7 +34,7 @@ def build_model(hyperparams) -> Model: units=64, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate) ), Dropout(rate=0.3), - Dense(units=len(BASES), activation="softmax"), + Dense(units=32, activation="softmax"), ] ) model.compile( @@ -71,7 +71,7 @@ def train_model(data_file, label_file, seed_value=42) -> None: model.save("trained_model") -async def infer_sequence(sequence) -> None: +async def infer_sequence(sequence) -> str: """ Predict the correct sequence, using the trained model """ @@ -81,5 +81,4 @@ async def infer_sequence(sequence) -> None: prediction = model.predict(one_hot_encoded_sequence) encoded_prediction = argmax(prediction, axis=1) final_prediction = decode_sequence(encoded_prediction) - print(f"Error-corrected sequence: {final_prediction}") return final_prediction