From 403fa231060dda617125cf244bf90f0cd96abd08 Mon Sep 17 00:00:00 2001 From: coolneng Date: Tue, 6 Jul 2021 06:21:32 +0200 Subject: [PATCH] Serve model via REST API --- pyproject.toml | 2 ++ src/api.py | 21 +++++++++++++++++++++ src/model.py | 7 ++++--- 3 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 src/api.py diff --git a/pyproject.toml b/pyproject.toml index 3a6f91d..b033590 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,8 @@ license = "GPL-3.0-or-later" python = "3.8.*" tensorflow = "^2.4.1" biopython = "^1.78" +fastapi = "^0.66.0" +uvicorn = "^0.14.0" [tool.poetry.dev-dependencies] isort = "^5.8.0" diff --git a/src/api.py b/src/api.py new file mode 100644 index 0000000..45a9319 --- /dev/null +++ b/src/api.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from pydantic import BaseModel +from model import infer_sequence + +app = FastAPI() + + +class Input(BaseModel): + sequence: str + + +@app.get("/{sequence}") +async def get_sequence_path(sequence: str): + correct_sequence = await infer_sequence(sequence) + return {"sequence": correct_sequence} + + +@app.post("/") +async def get_sequence_body(sequence: Input): + correct_sequence = await infer_sequence(sequence.sequence) + return {"sequence": correct_sequence} diff --git a/src/model.py b/src/model.py index 1b451ab..7e941dd 100644 --- a/src/model.py +++ b/src/model.py @@ -51,8 +51,8 @@ def show_metrics(model, eval_dataset, test_dataset) -> None: """ eval_metrics = model.evaluate(eval_dataset, verbose=0) test_metrics = model.evaluate(test_dataset, verbose=0) - print(f"Final eval metrics - loss: {eval_metrics[0]} - accuracy: {eval_metrics[1]}") - print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}") + print(f"Eval metrics {eval_metrics}") + print(f"Test metrics {test_metrics}") def train_model(data_file, label_file, seed_value=42) -> None: @@ -71,7 +71,7 @@ def train_model(data_file, label_file, seed_value=42) -> None: model.save("trained_model") -def infer_sequence(sequence) -> None: +async def infer_sequence(sequence) -> None: """ Predict the correct sequence, using the trained model """ @@ -82,3 +82,4 @@ def infer_sequence(sequence) -> None: encoded_prediction = argmax(prediction, axis=1) final_prediction = decode_sequence(encoded_prediction) print(f"Error-corrected sequence: {final_prediction}") + return final_prediction