Compare commits
1 Commits
8d7d2203b5
...
78acc54e5f
Author | SHA1 | Date |
---|---|---|
coolneng | 78acc54e5f |
|
@ -22,7 +22,8 @@ def execute_task(args):
|
|||
if args.task == "train":
|
||||
train_model(data_file=args.data_file, label_file=args.label_file)
|
||||
else:
|
||||
infer_sequence(sequence=args.sequence)
|
||||
prediction = await infer_sequence(sequence=args.sequence)
|
||||
print(f"Error-corrected sequence: {prediction}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
|
@ -71,7 +71,7 @@ def train_model(data_file, label_file, seed_value=42) -> None:
|
|||
model.save("trained_model")
|
||||
|
||||
|
||||
async def infer_sequence(sequence) -> None:
|
||||
async def infer_sequence(sequence) -> str:
|
||||
"""
|
||||
Predict the correct sequence, using the trained model
|
||||
"""
|
||||
|
@ -81,5 +81,4 @@ async def infer_sequence(sequence) -> None:
|
|||
prediction = model.predict(one_hot_encoded_sequence)
|
||||
encoded_prediction = argmax(prediction, axis=1)
|
||||
final_prediction = decode_sequence(encoded_prediction)
|
||||
print(f"Error-corrected sequence: {final_prediction}")
|
||||
return final_prediction
|
||||
|
|
Loading…
Reference in New Issue