Serve model via REST API
This commit is contained in:
parent
f91abfe43d
commit
403fa23106
|
@ -9,6 +9,8 @@ license = "GPL-3.0-or-later"
|
|||
python = "3.8.*"
|
||||
tensorflow = "^2.4.1"
|
||||
biopython = "^1.78"
|
||||
fastapi = "^0.66.0"
|
||||
uvicorn = "^0.14.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
isort = "^5.8.0"
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from model import infer_sequence
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class Input(BaseModel):
|
||||
sequence: str
|
||||
|
||||
|
||||
@app.get("/{sequence}")
|
||||
async def get_sequence_path(sequence: str):
|
||||
correct_sequence = await infer_sequence(sequence)
|
||||
return {"sequence": correct_sequence}
|
||||
|
||||
|
||||
@app.post("/")
|
||||
async def get_sequence_body(sequence: Input):
|
||||
correct_sequence = await infer_sequence(sequence.sequence)
|
||||
return {"sequence": correct_sequence}
|
|
@ -51,8 +51,8 @@ def show_metrics(model, eval_dataset, test_dataset) -> None:
|
|||
"""
|
||||
eval_metrics = model.evaluate(eval_dataset, verbose=0)
|
||||
test_metrics = model.evaluate(test_dataset, verbose=0)
|
||||
print(f"Final eval metrics - loss: {eval_metrics[0]} - accuracy: {eval_metrics[1]}")
|
||||
print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}")
|
||||
print(f"Eval metrics {eval_metrics}")
|
||||
print(f"Test metrics {test_metrics}")
|
||||
|
||||
|
||||
def train_model(data_file, label_file, seed_value=42) -> None:
|
||||
|
@ -71,7 +71,7 @@ def train_model(data_file, label_file, seed_value=42) -> None:
|
|||
model.save("trained_model")
|
||||
|
||||
|
||||
def infer_sequence(sequence) -> None:
|
||||
async def infer_sequence(sequence) -> None:
|
||||
"""
|
||||
Predict the correct sequence, using the trained model
|
||||
"""
|
||||
|
@ -82,3 +82,4 @@ def infer_sequence(sequence) -> None:
|
|||
encoded_prediction = argmax(prediction, axis=1)
|
||||
final_prediction = decode_sequence(encoded_prediction)
|
||||
print(f"Error-corrected sequence: {final_prediction}")
|
||||
return final_prediction
|
||||
|
|
Loading…
Reference in New Issue