К чему относится скрытое состояние источника в механизме внимания?

Веса внимания рассчитываются как:

введите здесь описание изображения

Я хочу знать, что означает h_s.

В коде тензорного потока кодировщик RNN возвращает кортеж:

encoder_outputs, encoder_state = tf.nn.dynamic_rnn(...)

Как мне кажется, h_s должно быть encoder_state, но github/nmt дает другой ответ?

# attention_states: [batch_size, max_time, num_units]
attention_states = tf.transpose(encoder_outputs, [1, 0, 2])

# Create an attention mechanism
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units, attention_states,
    memory_sequence_length=source_sequence_length)

Я неправильно понял код? Или h_s на самом деле означает encoder_outputs?


person imhuay    schedule 23.01.2018    source источник


Ответы (1)


Формула, вероятно, взята из этот пост, поэтому Я буду использовать картинку NN из того же поста:

нн

Здесь h-bar(s) — это все синие скрытые состояния из кодировщика (последний слой), а h(t) — текущее красное скрытое состояние из декодера (также последний слой). . На картинке t=0 вы можете увидеть, какие блоки подключены к весам внимания с помощью пунктирных стрелок. Функция score обычно одна из таких:

формула


Механизм внимания Tensorflow соответствует этой картине. Теоретически вывод ячейки является в большинстве случаев ее скрытым состоянием (одним исключением является ячейка LSTM, в которой вывод является краткосрочной частью состояния, и даже в этом случае вывод лучше подходит для механизм внимания). На практике encoder_state тензорного потока отличается от encoder_outputs, когда вход дополняется нулями: состояние распространяется из предыдущего состояния ячейки, а выход равен нулю. Очевидно, вы не хотите обращать внимание на конечные нули, поэтому имеет смысл иметь h-bar(s) для этих ячеек.

Итак, encoder_outputs — это именно те стрелки, которые идут от синих блоков вверх. Позже в коде attention_mechanism подключается к каждому decoder_cell, так что его выход проходит через вектор контекста к желтому блоку на картинке.

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism,
    attention_layer_size=num_units)
person Maxim    schedule 23.01.2018