Как загрузить контрольные точки графа (.ckpt) и использовать SavedModelBuilder, чтобы сохранить его как protobuf без объявления каких-либо tf.Variables?

В настоящее время у меня есть resnet_v2_50.ckpt из предварительно обученной модели с открытым исходным кодом tensorflow. Я пытаюсь использовать эту модель в Go, потому что мой бэкенд для моего веб-приложения будет в Go. Если бы я создал свою собственную модель и обучил ее, а затем сохранил. У меня нет проблем с обслуживанием в Go, но я пытаюсь использовать предварительно обученную модель, чтобы сэкономить время с моей стороны.

Вот простой пример того, как я сохраняю свою модель

mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)

# Recall that each image is 28x28
x = tf.placeholder(tf.float32, [None, 784], name='imageinput')
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.add(tf.matmul(x, W), b)
labels = tf.placeholder(tf.float32, [None, 10])
cross_entropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy_loss)

with tf.Session() as sess:
    with tf.device("/cpu:0"):
        sess.run(tf.global_variables_initializer())
        for i in range(1000):
            batch_x, batch_label = mnist.train.next_batch(100)
            loss, _ = sess.run([cross_entropy_loss, train_step], feed_dict={x: batch_x, labels: batch_label})
            print '%d: %f' % (i + 1, loss)

        infer = tf.argmax(y, axis=1, name='infer')
        truth = tf.argmax(labels, axis=1)
        correct_prediction = tf.equal(infer, truth)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print sess.run(accuracy, feed_dict={x: mnist.test.images, labels: mnist.test.labels})

        print 'Time to save the graph!'
        builder = tf.saved_model.builder.SavedModelBuilder('mnist_model')
        builder.add_meta_graph_and_variables(sess, ['serve'])
        builder.save()

Я могу загрузить его с помощью tensorflow в Go

model, err := tf.LoadSavedModel("./tf_mnist_py/mnist_model", []string{"serve"}, nil)
if err != nil {
    fmt.Printf("Error loading saved model: %s\n", err.Error())
    return
}

defer model.Session.Close()

Но теперь, когда дело доходит до предварительно обученной модели, я имею дело с ckpt файлами. У меня есть одно решение: загрузить его в Python, а затем сохранить как protobuf.

from tensorflow.python.tools import inspect_checkpoint as ckpt
ckpt.print_tensors_in_checkpoint_file('./resnet50/resnet_v2_50.ckpt',
                                      tensor_name='',
                                      all_tensors=False,
                                      all_tensor_names=False)

tf.reset_default_graph()
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './resnet50/resnet_v2_50.ckpt')
    print 'Model is restored'

    print 'Time to save the graph!'
    builder = tf.saved_model.builder.SavedModelBuilder('resnet_50_model')
    builder.add_meta_graph_and_variables(sess, ['serve'])
    builder.save()

Однако это дает мне сообщение об ошибке, что ValueError: No variables to save. Я могу исправить это, объявив переменную

v1 = tf.get_variable('total_loss/ExponentialMovingAverage', shape=[])

Но вот мой вопрос: означает ли это, что я должен объявить КАЖДУЮ переменную в ResNet50 и заставить tensorflow загрузить значения из файла ckpt в эти переменные, а затем выполнить сохранение? Есть ли способ сделать это?


person mofury    schedule 31.03.2018    source источник


Ответы (1)


Вы также должны импортировать переменные из файла .meta, если он доступен с предварительно обученной моделью, например

saver = tf.train.import_meta_graph('./resnet50/resnet_v2_50.meta')

Подробное руководство здесь.

Если у вас нет доступного .meta, но у вас есть код генерации сети, например, с resnet_v2_50 в tensorflow/models/blob/master/research/slim/nets/resnet_v2.py, то вам следует импортировать этот файл и запустить функцию resnet_v2_50, которая определит все переменные для ты. Затем восстановите контрольную точку.

person Peter Szoldan    schedule 31.03.2018
comment
Я заметил, что файл .meta помогает мне восстановить график в tensorflow (со всеми его переменными), но я не смог найти метаграф в папке resnet_v2_50, которую я скачал с github.com/tensorflow/models/tree/master/research/slim, какой важный шаг я пропустил? У меня в папке есть train.graph и eval.graph. - person mofury; 01.04.2018
comment
Отредактировал свой пост как ответ на ваш комментарий; добавлена ​​ссылка на код и инструкции resnet_v2_50. - person Peter Szoldan; 01.04.2018