trt_exporter.py

  1. """ Export models. """
  2. import logging
  3. from model import MovierecModel
  4. import os
  5. import tensorflow as tf
  6. from tensorflow.python.keras import backend as K
  7. from tensorflow.python.framework import graph_io
# For versions older than 1.14, use graph_util from 1.14 (locally downloaded).
tf_version = tf.__version__.split('.')
if tf_version[0] == '1' and int(tf_version[1]) < 14:
    logging.warning("Importing graph_util from local copy because tensorflow version {} is lower than 1.14"
                    .format(tf.__version__))
    from util import tf_graph_util
else:
    import tensorflow.compat.v1.graph_util as tf_graph_util

CONFIG_FILE_NAME = "config.pbtxt"
MODEL_VERSION_DIR_NAME = '1'
MAX_BATCH_SIZE = 1024

# Some data type mappings, from:
# https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/model_configuration.html#section-datatypes
DATA_TYPES_MAP = {
    tf.int16: "TYPE_INT16",
    tf.int32: "TYPE_INT32",
    tf.int64: "TYPE_INT64",
    tf.float16: "TYPE_FP16",
    tf.float32: "TYPE_FP32",
    tf.float64: "TYPE_FP64",
}
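
# For example, a tf.float32 tensor is declared as TYPE_FP32 in the generated
# config; a dtype missing from this map (e.g. tf.uint8) would raise a KeyError
# in _write_config_file below.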


def export_keras_model_to_trt(input_dir, model_name, output_dir):
    """
    Export a saved Keras model to a TensorRT-compatible model. Steps: load the Keras
    model from file, freeze and optimize the graph for inference, and save it in a
    TensorRT-compatible format.

    See:
    https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/model_repository.html#tensorflow-models

    Parameters
    ----------
    input_dir : str or `os.path`
        Directory of the Keras model files.
    model_name : str
        Name of the model to load.
    output_dir : str or `os.path`
        Directory to save the exported model.
    """
    # Set the learning phase to 'test' so training-only behavior (e.g. dropout) is disabled.
    K.set_learning_phase(0)
    model = MovierecModel.load_from_dir(input_dir, model_name).model

    with K.get_session() as session:
        graph = session.graph
        with graph.as_default():
            graph_def = graph.as_graph_def()

            # Freeze the graph: drop training-only nodes and convert variables to constants.
            frozen_graph = tf_graph_util.remove_training_nodes(graph_def)
            frozen_graph = tf_graph_util.convert_variables_to_constants(
                sess=session,
                input_graph_def=frozen_graph,
                output_node_names=[out.op.name for out in model.outputs])

            # Directory structure and file name as expected by TensorRT.
            output_file_path = os.path.join(output_dir, model_name, MODEL_VERSION_DIR_NAME)
            graph_io.write_graph(
                graph_or_graph_def=frozen_graph,
                logdir=output_file_path,
                name='model.graphdef',
                as_text=False)
            logging.info("Saved graph def file to {}".format(output_file_path))

    # Write the config, setting only one output.
    _write_config_file(output_dir, model_name, model.inputs, [model.outputs[0]])
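
# For orientation, the export above lays files out the way the TensorRT Inference
# Server model repository expects ("movierec" is illustrative and stands in for the
# model_name argument):
#
#   <output_dir>/
#       movierec/
#           config.pbtxt       <- written by _write_config_file
#           1/
#               model.graphdef <- frozen graph written above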


def _write_config_file(output_dir, model_name, inputs, outputs):
    """
    Write the model configuration file.

    See:
    https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/model_configuration.html

    Parameters
    ----------
    output_dir : str or `os.path`
        Directory of the exported model.
    model_name : str
        Name of the model.
    inputs : list of `tf.Tensor`
        Model inputs.
    outputs : list of `tf.Tensor`
        Model outputs.
    """
    config_str = 'name: "{}"\n' \
                 'platform: "tensorflow_graphdef"\n' \
                 'max_batch_size: {}\n' \
                 'input [\n'.format(model_name, MAX_BATCH_SIZE)

    def get_tensors_str_list(tensors):
        tensors_str_list = []
        for tensor in tensors:
            name = tensor.op.name
            data_type = DATA_TYPES_MAP[tensor.dtype]
            # Dims describe a single instance, so skip the batch dimension; protobuf
            # text format expects comma-separated values inside the brackets.
            dims = ', '.join([str(dim.value) for dim in tensor.shape[1:]])
            tensors_str_list.append(' {{\n'
                                    '  name: "{}"\n'
                                    '  data_type: {}\n'
                                    '  dims: [ {} ]\n'
                                    ' }}'.format(name, data_type, dims))
        return tensors_str_list

    config_str += ','.join(get_tensors_str_list(inputs))
    config_str += '\n]\n' \
                  'output [\n'
    config_str += ','.join(get_tensors_str_list(outputs))
    config_str += '\n]\n'

    # Write to file.
    config_file_path = os.path.join(output_dir, model_name, CONFIG_FILE_NAME)
    with open(config_file_path, 'w') as f_out:
        f_out.write(config_str)
    logging.info("Saved config file to {}".format(config_file_path))
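
# As a rough sketch: for a hypothetical model with one int32 input of shape
# (None, 10) and one float32 output of shape (None, 5), exported under the name
# "movierec", the generated config.pbtxt would look roughly like this (the tensor
# names "item_ids" and "scores" are illustrative, not from the actual model):
#
#   name: "movierec"
#   platform: "tensorflow_graphdef"
#   max_batch_size: 1024
#   input [
#    {
#     name: "item_ids"
#     data_type: TYPE_INT32
#     dims: [ 10 ]
#    }
#   ]
#   output [
#    {
#     name: "scores"
#     data_type: TYPE_FP32
#     dims: [ 5 ]
#    }
#   ]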


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Export a Keras model to a TensorRT-compatible format.')
    parser.add_argument('-m', '--model-name', type=str, required=True,
                        help='Model name')
    parser.add_argument('-i', '--input-model-dir', type=str, required=True,
                        help='Input model directory (absolute or relative path)')
    parser.add_argument('-o', '--output-dir', type=str, required=True,
                        help='Output directory (absolute or relative path)')
    parser.add_argument('-l', '--log-level', type=str, default='INFO',
                        help='Log level (default: INFO).')
    args = parser.parse_args()

    logging.getLogger().setLevel(logging.getLevelName(args.log_level))
    export_keras_model_to_trt(args.input_model_dir, args.model_name, args.output_dir)
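
# Example invocation (model name and paths are illustrative):
#
#   python trt_exporter.py --model-name movierec \
#       --input-model-dir saved_models --output-dir /tmp/trt_models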