Экспорт графика тензорного потока с помощью export_saved_model

Я пытаюсь обучить и развернуть упрощенный Quick, Draw! классификатор из здесь в Google Cloud. Мне удалось обучить модель в GC, теперь я застрял на ее развертывании, точнее на создание обслуживающих функций ввода.

Я следую инструкциям из здесь, и мне трудно понять, какой тип ввода тензор должен быть.

Ошибка:

TypeError: не удалось преобразовать объект типа в тензор. Содержание: SparseTensor (индексы = Tensor ("ParseExample / ParseExample: 0", shape = (?, 2), dtype = int64), values ​​= Tensor ("ParseExample / ParseExample: 1", shape = (?,), dtype = float32), density_shape = Tensor ("ParseExample / ParseExample: 2", shape = (2,), dtype = int64)). Рассмотрите возможность приведения элементов к поддерживаемому типу.

Функция обслуживания:

def serving_input_receiver_fn():
  serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[None], name='input_tensors')
  receiver_tensors = {'infer_inputs': serialized_tf_example}
  features = tf.parse_example(serialized_tf_example, feature_spec)
  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

Спецификация функции:

feature_spec = {
    "ink": tf.VarLenFeature(dtype=tf.float32),
    "shape": tf.FixedLenFeature([2], dtype=tf.int64)
}

Входной слой:

def _get_input_tensors(features, labels):
  shapes = features["shape"]
  lengths = tf.squeeze(
    tf.slice(shapes, begin=[0, 0], size=[params.batch_size, 1]))
  inks = tf.reshape(features["ink"], [params.batch_size, -1, 3])

  if labels is not None:
    labels = tf.squeeze(labels)
  return inks, lengths, labels

Код модели и данные обучения были взяты здесь.


person theuses    schedule 13.11.2018    source источник


Ответы (1)


Попробуй это:

def serving_input_receiver_fn():
  ink = tf.placeholder(dtype=tf.float32, shape=[None, None, 3], name='ink')
  length = tf.placeholder(dtype=tf.int64, shape=[None, 1])
  features = {"ink": inks, "length": lengths}
  return tf.estimator.export.ServingInputReceiver(features, features)

Пример полезной нагрузки:

{"instances": [{"ink": [[0.1, 1.0, 2.0]], "length":[[1]]}]}

или как ввод в gcloud predict --json-instances:

{"ink": [[0.1, 1.0, 2.0]], "length":[[1]]}]

Я не изучал реальный код; если чернила обычно содержат много плавающих символов, вы можете рассмотреть альтернативную систему кодирования.

person rhaertel80    schedule 13.11.2018
comment
Спасибо, добавили форму к функциям и отключили добавление потерь и оптимизаторов для режима прогнозирования, и это сработало! - person theuses; 13.11.2018
comment
рад это слышать! - person rhaertel80; 16.11.2018