Вывод на предварительно обученную модель ONNX из Unity ml-агентов в Tensorflow

У меня есть предварительно обученная модель от ml-агентов Unity. Теперь я пытаюсь сделать вывод с этой моделью на Python с помощью TensorFlow. Для этого я использую TensorFlow Backend для ONNX, чтобы сохранить модель ONNX как SavedModel, чтобы я мог позже загрузите эту модель. Код, используемый для сохранения модели:

import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load(model_path)  # load onnx model
tf_rep = prepare(onnx_model, logging_level='DEBUG')
tf_rep.export_graph(output_path)

код для загрузки модели и запуска тестового примера

imported = tf.saved_model.load(
     model_dir, tags=None, options=None
)
f = imported.signatures["serving_default"]
print(f(visual_observation_0=tf.cast(forward, tf.float32), 
          visual_observation_1=tf.cast(body, tf.float32)))

Теперь есть несколько вопросов.

  1. результат теста имеет 6 выходных значений. (см. изображение ниже для визуального представления файла ONNX)
  2. При попытке сохранить модель я получил следующее сообщение (см. Отладочную информацию ниже)

Не уверен, что здесь происходит. Любая помощь очень ценится  Визуальная модель

2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Unknown op Celu in domain 'ai.onnx'.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Fail to get since_version of ConcatFromSequence in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Unknown op ConstantFill in domain 'ai.onnx'.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Fail to get since_version of ConvInteger in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Fail to get since_version of CumSum in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Fail to get since_version of DequantizeLinear in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Fail to get since_version of Det in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Fail to get since_version of DynamicQuantizeLinear in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Unknown op Einsum in domain 'ai.onnx'.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Fail to get since_version of GatherElements in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Fail to get since_version of GatherND in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Unknown op GreaterOrEqual in domain 'ai.onnx'.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Unknown op ImageScaler in domain 'ai.onnx'.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Fail to get since_version of IsInf in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Unknown op LessOrEqual in domain 'ai.onnx'.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of MatMulInteger in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of Mod in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of NonMaxSuppression in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of QLinearConv in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of QLinearMatMul in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of QuantizeLinear in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of Range in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of Resize in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of ReverseSequence in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of RoiAlign in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of Round in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of ScatterElements in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of ScatterND in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceAt in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceConstruct in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceEmpty in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceErase in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceInsert in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceLength in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SplitToSequence in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,271 - onnx-tf - DEBUG - Fail to get since_version of ThresholdedRelu in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03.273323: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-03-24 17:52:03.286901: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f912d05cf60 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2021-03-24 17:52:03.286913: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2021-03-24 17:52:07.450878: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.

person theVortr3x    schedule 24.03.2021    source источник


Ответы (1)


Окей, так получается, что на выходе сеть тоже выдает версию и другие параметры.

onnx_model = onnx.load(model_path)  # load onnx model
tf_rep = prepare(onnx_model)

print(tf_rep.inputs) # Input nodes to the model
> output: ['visual_observation_0', 'visual_observation_1']
print(tf_rep.outputs) # Output nodes from the model
> output: ['version_number', 'memory_size', 'continuous_actions', 'continuous_action_output_shape', 'action', 'is_continuous_control', 'action_output_shape']

Вход был таким, как я ожидал. Однако на выходе также есть номер версии, память и так далее. Меня интересует только continuous_actions. Мне также пришлось масштабировать изображения от [0, 1]

person theVortr3x    schedule 25.03.2021