Serve model via REST API

This commit is contained in:
coolneng 2021-07-06 06:21:32 +02:00
parent f91abfe43d
commit 403fa23106
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
3 changed files with 27 additions and 3 deletions

View File

@ -9,6 +9,8 @@ license = "GPL-3.0-or-later"
python = "3.8.*" python = "3.8.*"
tensorflow = "^2.4.1" tensorflow = "^2.4.1"
biopython = "^1.78" biopython = "^1.78"
fastapi = "^0.66.0"
uvicorn = "^0.14.0"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
isort = "^5.8.0" isort = "^5.8.0"

21
src/api.py Normal file
View File

@ -0,0 +1,21 @@
from fastapi import FastAPI
from pydantic import BaseModel
from model import infer_sequence
app = FastAPI()
class Input(BaseModel):
sequence: str
@app.get("/{sequence}")
async def get_sequence_path(sequence: str):
correct_sequence = await infer_sequence(sequence)
return {"sequence": correct_sequence}
@app.post("/")
async def get_sequence_body(sequence: Input):
correct_sequence = await infer_sequence(sequence.sequence)
return {"sequence": correct_sequence}

View File

@ -51,8 +51,8 @@ def show_metrics(model, eval_dataset, test_dataset) -> None:
""" """
eval_metrics = model.evaluate(eval_dataset, verbose=0) eval_metrics = model.evaluate(eval_dataset, verbose=0)
test_metrics = model.evaluate(test_dataset, verbose=0) test_metrics = model.evaluate(test_dataset, verbose=0)
print(f"Final eval metrics - loss: {eval_metrics[0]} - accuracy: {eval_metrics[1]}") print(f"Eval metrics {eval_metrics}")
print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}") print(f"Test metrics {test_metrics}")
def train_model(data_file, label_file, seed_value=42) -> None: def train_model(data_file, label_file, seed_value=42) -> None:
@ -71,7 +71,7 @@ def train_model(data_file, label_file, seed_value=42) -> None:
model.save("trained_model") model.save("trained_model")
def infer_sequence(sequence) -> None: async def infer_sequence(sequence) -> None:
""" """
Predict the correct sequence, using the trained model Predict the correct sequence, using the trained model
""" """
@ -82,3 +82,4 @@ def infer_sequence(sequence) -> None:
encoded_prediction = argmax(prediction, axis=1) encoded_prediction = argmax(prediction, axis=1)
final_prediction = decode_sequence(encoded_prediction) final_prediction = decode_sequence(encoded_prediction)
print(f"Error-corrected sequence: {final_prediction}") print(f"Error-corrected sequence: {final_prediction}")
return final_prediction