Tensorflow — несовместимая форма

Классифицируйте цифры MNIST с помощью Tensorflow с помощью двухуровневого подхода RNN. Обучение работает нормально, но при оценке точности сообщается о несовместимой форме тестовых данных.

import tensorflow as tf
import inspect
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot = True)

hm_epochs = 1
n_classes = 10
batch_size = 128
chunk_size = 28
n_chunks = 28
rnn_size = 128

x = tf.placeholder('float', [None, n_chunks,chunk_size])
y = tf.placeholder('float')

def lstm_cell():
      if 'reuse' in inspect.getargspec(
          tf.contrib.rnn.BasicLSTMCell.__init__).args:
        return tf.contrib.rnn.BasicLSTMCell(
            rnn_size, forget_bias=0.0, state_is_tuple=True,
            reuse=tf.get_variable_scope().reuse)
      else:
        return tf.contrib.rnn.BasicLSTMCell(
            rnn_size, forget_bias=0.0, state_is_tuple=True)

def attn_cell():
        return tf.contrib.rnn.DropoutWrapper(
            lstm_cell())

def recurrent_neural_network(x):
    layer = {'weights':tf.Variable(tf.random_normal([rnn_size,n_classes])),
             'biases':tf.Variable(tf.random_normal([n_classes]))}

    x = tf.transpose(x, [1,0,2])
    x = tf.reshape(x, [-1, chunk_size])
    x = tf.split(x, n_chunks, 0)

    stacked_lstm = tf.contrib.rnn.MultiRNNCell([attn_cell(),attn_cell()], state_is_tuple=True)

    initial_state = state = stacked_lstm.zero_state(batch_size, tf.float32)   
    outputs, states = tf.contrib.rnn.static_rnn(stacked_lstm, x,state)
    output = tf.matmul(outputs[-1],layer['weights']) + layer['biases']

    return output

def train_neural_network(x):
    prediction = recurrent_neural_network(x)

    cost = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction) )
    optimizer = tf.train.AdamOptimizer().minimize(cost)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())

        for epoch in range(hm_epochs):
            epoch_loss = 0
            for _ in range(int(mnist.train.num_examples/batch_size)):
                epoch_x, epoch_y = mnist.train.next_batch(batch_size)
                epoch_x = epoch_x.reshape((batch_size,n_chunks,chunk_size))

                _, c = sess.run([optimizer, cost], feed_dict={x: epoch_x, y: epoch_y})
                epoch_loss += c

            print('Epoch', epoch, 'completed out of',hm_epochs,'loss:',epoch_loss)

        correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))

        accuracy = tf.reduce_mean(tf.cast(correct, 'float'))

        testdata= np.reshape( mnist.test.images, (10000, n_chunks, chunk_size))
        print("Testdata ",testdata.shape)
        print("x ",x)
        print('Accuracy:',accuracy.eval({x:testdata, y:mnist.test.labels}))

train_neural_network(x)

Однако формы тестовых данных и заполнителей печатаются следующим образом. Разве они не совместимы?

Epoch 0 completed out of 1 loss: 228.159379691
Testdata  (10000, 28, 28)
x  Tensor("Placeholder:0", shape=(?, 28, 28), dtype=float32)

Ошибка:

Caused by op 'rnn/rnn/multi_rnn_cell/cell_0/cell_0/basic_lstm_cell/basic_lstm_ce
ll/concat', defined at:
  File "main.py", line 90, in <module>
    train_neural_network(x)
  File "main.py", line 59, in train_neural_network
    prediction = recurrent_neural_network(x)
  File "main.py", line 52, in recurrent_neural_network
    outputs, states = tf.contrib.rnn.static_rnn(stacked_lstm, x,state)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn.py"
, line 1212, in static_rnn
    (output, state) = call_cell()
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn.py"
, line 1199, in <lambda>
    call_cell = lambda: cell(input_, state)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn_cel
l_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\layers\base
.py", line 441, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn_cel
l_impl.py", line 916, in call
    cur_inp, new_state = cell(cur_inp, cur_state)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn_cel
l_impl.py", line 752, in __call__
    output, new_state = self._cell(inputs, state, scope)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn_cel
l_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\layers\base
.py", line 441, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn_cel
l_impl.py", line 383, in call
    concat = _linear([inputs, h], 4 * self._num_units, True)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn_cel
l_impl.py", line 1021, in _linear
    res = math_ops.matmul(array_ops.concat(args, 1), weights)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\ops\array_o
ps.py", line 1048, in concat
    name=name)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_arr
ay_ops.py", line 495, in _concat_v2
    name=name)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\framework\o
p_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\framework\o
ps.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "C:\Users\henry\Anaconda3\lib\site-packages\tensorflow\python\framework\o
ps.py", line 1269, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs
should match: shape[0] = [10000,28] vs. shape[1] = [128,128]
         [[Node: rnn/rnn/multi_rnn_cell/cell_0/cell_0/basic_lstm_cell/basic_lstm
_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/
replica:0/task:0/cpu:0"](split, MultiRNNCellZeroState/DropoutWrapperZeroState/Ba
sicLSTMCellZeroState/zeros_1, rnn/rnn/multi_rnn_cell/cell_0/cell_0/basic_lstm_ce
ll/basic_lstm_cell/concat/axis)]]

Когда я печатаю форму обучающих данных, это (128,28,28). Меня смущает, почему тестовые данные приводят к ошибке, потому что и тренировочные данные, и тестовые данные отформатированы одинаково, то есть (?, n_chunks, chunk_size). Заранее спасибо.


person Saraj Muhammed    schedule 03.08.2017    source источник


Ответы (1)


Проблема в том, что вы всегда создаете начальное состояние с формой, установленной на размер пакета обучения, а не на размер пакета eval.

Это виновная строка:

    initial_state = state = stacked_lstm.zero_state(batch_size, tf.float32)   
person Alexandre Passos    schedule 15.11.2018