Serve model via REST API
This commit is contained in:
parent
f91abfe43d
commit
403fa23106
|
@ -9,6 +9,8 @@ license = "GPL-3.0-or-later"
|
||||||
python = "3.8.*"
|
python = "3.8.*"
|
||||||
tensorflow = "^2.4.1"
|
tensorflow = "^2.4.1"
|
||||||
biopython = "^1.78"
|
biopython = "^1.78"
|
||||||
|
fastapi = "^0.66.0"
|
||||||
|
uvicorn = "^0.14.0"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
isort = "^5.8.0"
|
isort = "^5.8.0"
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from model import infer_sequence
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
class Input(BaseModel):
|
||||||
|
sequence: str
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/{sequence}")
|
||||||
|
async def get_sequence_path(sequence: str):
|
||||||
|
correct_sequence = await infer_sequence(sequence)
|
||||||
|
return {"sequence": correct_sequence}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/")
|
||||||
|
async def get_sequence_body(sequence: Input):
|
||||||
|
correct_sequence = await infer_sequence(sequence.sequence)
|
||||||
|
return {"sequence": correct_sequence}
|
|
@ -51,8 +51,8 @@ def show_metrics(model, eval_dataset, test_dataset) -> None:
|
||||||
"""
|
"""
|
||||||
eval_metrics = model.evaluate(eval_dataset, verbose=0)
|
eval_metrics = model.evaluate(eval_dataset, verbose=0)
|
||||||
test_metrics = model.evaluate(test_dataset, verbose=0)
|
test_metrics = model.evaluate(test_dataset, verbose=0)
|
||||||
print(f"Final eval metrics - loss: {eval_metrics[0]} - accuracy: {eval_metrics[1]}")
|
print(f"Eval metrics {eval_metrics}")
|
||||||
print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}")
|
print(f"Test metrics {test_metrics}")
|
||||||
|
|
||||||
|
|
||||||
def train_model(data_file, label_file, seed_value=42) -> None:
|
def train_model(data_file, label_file, seed_value=42) -> None:
|
||||||
|
@ -71,7 +71,7 @@ def train_model(data_file, label_file, seed_value=42) -> None:
|
||||||
model.save("trained_model")
|
model.save("trained_model")
|
||||||
|
|
||||||
|
|
||||||
def infer_sequence(sequence) -> None:
|
async def infer_sequence(sequence) -> None:
|
||||||
"""
|
"""
|
||||||
Predict the correct sequence, using the trained model
|
Predict the correct sequence, using the trained model
|
||||||
"""
|
"""
|
||||||
|
@ -82,3 +82,4 @@ def infer_sequence(sequence) -> None:
|
||||||
encoded_prediction = argmax(prediction, axis=1)
|
encoded_prediction = argmax(prediction, axis=1)
|
||||||
final_prediction = decode_sequence(encoded_prediction)
|
final_prediction = decode_sequence(encoded_prediction)
|
||||||
print(f"Error-corrected sequence: {final_prediction}")
|
print(f"Error-corrected sequence: {final_prediction}")
|
||||||
|
return final_prediction
|
||||||
|
|
Loading…
Reference in New Issue