Compare commits

..

1 Commits

Author SHA1 Message Date
coolneng 8d7d2203b5
Change the number of units in the last Dense layer 2021-07-06 17:49:39 +02:00
2 changed files with 3 additions and 3 deletions

View File

@ -22,8 +22,7 @@ def execute_task(args):
if args.task == "train":
train_model(data_file=args.data_file, label_file=args.label_file)
else:
prediction = await infer_sequence(sequence=args.sequence)
print(f"Error-corrected sequence: {prediction}")
infer_sequence(sequence=args.sequence)
def main() -> None:

View File

@ -71,7 +71,7 @@ def train_model(data_file, label_file, seed_value=42) -> None:
model.save("trained_model")
async def infer_sequence(sequence) -> str:
async def infer_sequence(sequence) -> None:
"""
Predict the correct sequence, using the trained model
"""
@ -81,4 +81,5 @@ async def infer_sequence(sequence) -> str:
prediction = model.predict(one_hot_encoded_sequence)
encoded_prediction = argmax(prediction, axis=1)
final_prediction = decode_sequence(encoded_prediction)
print(f"Error-corrected sequence: {final_prediction}")
return final_prediction