Я знаю, что есть похожая тема в LSTM, за которой следует средний пул, но это речь идет о Keras, и я работаю в чистом TensorFlow.
У меня есть сеть LSTM, в которой повторение обрабатывается:
outputs, final_state = tf.nn.dynamic_rnn(cell,
embed,
sequence_length=seq_lengths,
initial_state=initial_state)
где я передаю правильную длину последовательности для каждого образца (заполнение нулями). В любом случае выходные данные содержат нерелевантные выходные данные, поскольку некоторые выборки производят более длинные выходные данные, чем другие, в зависимости от длины последовательности.
Прямо сейчас я извлекаю последний соответствующий вывод с помощью следующего метода:
def extract_axis_1(data, ind):
"""
Get specified elements along the first axis of tensor.
:param data: Tensorflow tensor that will be subsetted.
:param ind: Indices to take (one for each element along axis 0 of data).
:return: Subsetted tensor.
"""
batch_range = tf.range(tf.shape(data)[0])
indices = tf.stack([batch_range, ind], axis=1)
res = tf.reduce_mean(tf.gather_nd(data, indices), axis=0)
где я передаю sequence_length - 1
в качестве индексов. Что касается последней темы, я хотел бы выбрать все релевантные результаты, за которыми следует среднее объединение, а не только последний.
Теперь я попытался передать вложенные списки как индексы в extract_axis_1
, но tf.stack
не принимает это.
Любые направления решения для этого?