31 lines
1.1 KiB
Python
31 lines
1.1 KiB
Python
|
import sys
|
||
|
import os
|
||
|
|
||
|
def explain():
    """Print usage instructions for the converter.

    Shows an example invocation, reports how many command-line arguments
    were expected and how many were actually received, and then prints the
    version requirements by calling ``versions()``.
    """
    print('This converter can be used to transform keras models in the .h5 format into a tflite flatbuffer model .tflite\n')
    print('\tIt must be called like in this example: python keras2tflite.py /path/to/input_model.h5 ./output_model.tflite\n')
    # sys.argv[0] is the script name, so it is excluded from the count.
    # The subtraction must happen before str(): subtracting an int from the
    # concatenated string raises TypeError.
    print('\tArguments expected: 2\n\tArguments received: ' + str(len(sys.argv) - 1))
    versions()
|
||
|
|
||
|
def versions(k=''):
    """Print the minimum supported versions, one requirement per line.

    Args:
        k: Text prepended to every printed line (empty by default).
    """
    requirements = (
        'Python version >= 3.x',
        'Keras model version >= 2.2.4',
        'TensorFlow Lite version >= 1.11',
        'TensorFlow version >= 1.14.x',
    )
    for requirement in requirements:
        print(k + requirement)
|
||
|
|
||
|
# Script entry: expects exactly two arguments after the script name,
# the input Keras .h5 model and the output .tflite path.
if len(sys.argv) != 3:
    explain()
else:
    # Converting a tf.Keras model to a TensorFlow Lite model.
    # subprocess is imported here because only this branch needs it.
    import subprocess

    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters from being split or interpreted by the shell.
    command = [
        'tflite_convert',
        '--keras_model_file=' + sys.argv[1],
        '--output_file=' + sys.argv[2],
    ]
    try:
        status = subprocess.call(command)
    except OSError as e:
        # Raised when tflite_convert cannot be launched at all
        # (for example, it is not installed or not on PATH).
        print('\nError in the conversion to TFLite:\n' + str(e) + '\n')
    else:
        # A failed conversion is reported through the exit status, not an
        # exception, so success is only announced when the status is zero.
        if status != 0:
            print('\nError in the conversion to TFLite:\n'
                  'tflite_convert exited with status ' + str(status) + '\n')
        else:
            print('\nModel saved!\n')
|