Await prediction and print it in the caller
This commit is contained in:
parent
3ded0744b3
commit
fba3c5318b
|
@ -1,5 +1,6 @@
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from model import infer_sequence
|
from model import infer_sequence
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
|
@ -22,7 +22,8 @@ def execute_task(args):
|
||||||
if args.task == "train":
|
if args.task == "train":
|
||||||
train_model(data_file=args.data_file, label_file=args.label_file)
|
train_model(data_file=args.data_file, label_file=args.label_file)
|
||||||
else:
|
else:
|
||||||
infer_sequence(sequence=args.sequence)
|
prediction = infer_sequence(sequence=args.sequence)
|
||||||
|
print(f"Error-corrected sequence: {prediction}")
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
|
|
@ -34,7 +34,7 @@ def build_model(hyperparams) -> Model:
|
||||||
units=64, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
|
units=64, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
|
||||||
),
|
),
|
||||||
Dropout(rate=0.3),
|
Dropout(rate=0.3),
|
||||||
Dense(units=len(BASES), activation="softmax"),
|
Dense(units=32, activation="softmax"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
model.compile(
|
model.compile(
|
||||||
|
@ -71,7 +71,7 @@ def train_model(data_file, label_file, seed_value=42) -> None:
|
||||||
model.save("trained_model")
|
model.save("trained_model")
|
||||||
|
|
||||||
|
|
||||||
async def infer_sequence(sequence) -> None:
|
async def infer_sequence(sequence) -> str:
|
||||||
"""
|
"""
|
||||||
Predict the correct sequence, using the trained model
|
Predict the correct sequence, using the trained model
|
||||||
"""
|
"""
|
||||||
|
@ -81,5 +81,4 @@ async def infer_sequence(sequence) -> None:
|
||||||
prediction = model.predict(one_hot_encoded_sequence)
|
prediction = model.predict(one_hot_encoded_sequence)
|
||||||
encoded_prediction = argmax(prediction, axis=1)
|
encoded_prediction = argmax(prediction, axis=1)
|
||||||
final_prediction = decode_sequence(encoded_prediction)
|
final_prediction = decode_sequence(encoded_prediction)
|
||||||
print(f"Error-corrected sequence: {final_prediction}")
|
|
||||||
return final_prediction
|
return final_prediction
|
||||||
|
|
Loading…
Reference in New Issue