Сохранение модели кераса TF2 с пользовательскими определениями сигнатуры

У меня есть модель Keras (последовательная), которую можно сохранить с помощью настраиваемых определений подписи в Tensorflow 1.13 следующим образом:

from tensorflow.saved_model.utils import build_tensor_info
from tensorflow.saved_model.signature_def_utils import predict_signature_def, build_signature_def

model = Sequential() // with some layers

builder = tf.saved_model.builder.SavedModelBuilder(export_path)

score_signature = predict_signature_def(
    inputs={'waveform': model.input},
    outputs={'scores': model.output})

metadata = build_signature_def(
    outputs={'other_variable': build_tensor_info(tf.constant(1234, dtype=tf.int64))})

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  builder.add_meta_graph_and_variables(
      sess=sess,
      tags=[tf.saved_model.tag_constants.SERVING],
      signature_def_map={'score': score_signature, 'metadata': metadata})
  builder.save()

Миграция модели на TF2 keras была крутой :), но я не могу понять, как сохранить модель с той же подписью, что и выше. Что мне следует использовать: новый tf.saved_model.save() или tf.keras.experimental.export_saved_model()? Как следует написать приведенный выше код в TF2?

Ключевые требования:

  • Модель имеет подпись очков и подпись метаданных.
  • Подпись метаданных содержит 1 или несколько констант.

person Antony Harfield    schedule 19.06.2019    source источник


Ответы (1)


Решение состоит в том, чтобы создать tf.Module с функциями для каждого определения подписи:

class MyModule(tf.Module):
  def __init__(self, model, other_variable):
    self.model = model
    self._other_variable = other_variable

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32)])
  def score(self, waveform):
    result = self.model(waveform)
    return { "scores": results }

  @tf.function(input_signature=[])
  def metadata(self):
    return { "other_variable": self._other_variable }

А затем сохраните модуль (не модель):

module = MyModule(model, 1234)
tf.saved_model.save(module, export_path, signatures={ "score": module.score, "metadata": module.metadata })

Протестировано с моделью Keras на TF2.

person Antony Harfield    schedule 03.07.2019
comment
Чрезвычайно полезный ответ! Некоторое время я рыскал по Интернету, прежде чем нашел его. Также интересно: я тестировал его, и вам даже не нужно оборачивать tf.Module для пользовательского подкласса модели Keras. Вы можете просто добавить функцию метаданных в том виде, в каком вы ее написали, и все работает вместе с tf.keras.models.save/load. Написано здесь: stackoverflow.com/questions/54642590/ - person Carson McNeil; 25.08.2020