Await prediction and print it in the caller

This commit is contained in:
coolneng 2021-07-06 17:53:45 +02:00
parent 3ded0744b3
commit 78acc54e5f
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
3 changed files with 5 additions and 4 deletions

View File

@ -1,5 +1,6 @@
from fastapi import FastAPI
from pydantic import BaseModel
from model import infer_sequence
app = FastAPI()

View File

@ -22,7 +22,8 @@ def execute_task(args):
if args.task == "train":
train_model(data_file=args.data_file, label_file=args.label_file)
else:
infer_sequence(sequence=args.sequence)
prediction = await infer_sequence(sequence=args.sequence)
print(f"Error-corrected sequence: {prediction}")
def main() -> None:

View File

@ -34,7 +34,7 @@ def build_model(hyperparams) -> Model:
units=64, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
),
Dropout(rate=0.3),
Dense(units=len(BASES), activation="softmax"),
Dense(units=32, activation="softmax"),
]
)
model.compile(
@ -71,7 +71,7 @@ def train_model(data_file, label_file, seed_value=42) -> None:
model.save("trained_model")
async def infer_sequence(sequence) -> None:
async def infer_sequence(sequence) -> str:
"""
Predict the correct sequence, using the trained model
"""
@ -81,5 +81,4 @@ async def infer_sequence(sequence) -> None:
prediction = model.predict(one_hot_encoded_sequence)
encoded_prediction = argmax(prediction, axis=1)
final_prediction = decode_sequence(encoded_prediction)
print(f"Error-corrected sequence: {final_prediction}")
return final_prediction