умножать только определенные столбцы массива тензорного потока

В настоящее время я модифицирую функцию потерь для одной из моих нейронных сетей обнаружения объектов. У меня в основном два массива;

y_true: ярлыки предсказаний. tf тензор формы (x, y, z) y_pred: предсказанные значения. tf тензор формы (x, y, z) - измерение x - это размер пакета, измерение y - это количество предсказанных объектов в изображении, измерение z содержит быстрое кодирование классов, а также ограничивающие коробки указанных классов.

Теперь к настоящему вопросу: то, что я хочу сделать, это в основном умножить первые 5 значений z-значений в y_pred на первые 5 z-значений в y_true. Все остальные значения не должны измениться. В numpy это очень просто;

y_pred[:,:,:5] *= y_true[:,:,:5]

Мне очень сложно сделать это в тензорном потоке, поскольку я не могу присвоить значения исходному тензору, и я хочу, чтобы все остальные значения оставались такими же. Как мне это сделать в тензорном потоке?


person puffadder    schedule 05.04.2018    source источник


Ответы (1)


Начиная с версии 1.1, Tensorflow охватывает такое индексирование, подобное Numpy, см. Tensor.getitem.

import tensorflow as tf

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10], [10,20,30,40,50,60,70,80,90,100]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10], [10,20,30,40,50,60,70,80,90,100]]])
    print((y_pred[:,:,:5] * y_true[:,:,:5]).eval()) 
    # [[[   1    4    9   16   25]
    #   [ 100  400  900 1600 2500]]]

ИЗМЕНИТЬ после комментария:

Теперь проблема заключается в части "* =", то есть в назначении элемента. Это непростая операция в Tensorflow. Однако в вашем случае это можно легко решить с помощью tf.concat или tf.where (tf.dynamic_partition + tf.dynamic_stitch можно использовать для более сложных случаев).

Ниже вы найдете быструю реализацию двух первых решений.

Решение с использованием Tensor.getitem и tf.concat:

import tensorflow as tf

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])

    # tf.where can't apply the condition to any axis (see doc).
    # In your case (condition on 2nd axis), we need either to manually broadcast the
    # condition tensor, or transpose the target tensors.
    # Here is a quick demonstration with the 2nd solution:

    y_pred_edit = y_pred[:,:,:5] * y_true[:,:,:5]
    y_pred_rest = y_pred[:,:,4:]

    y_pred = tf.concat((y_pred_edit, y_pred_rest), axis=2)
    print(y_pred.eval())
    # [[[ 1  4  9 16 25  6  7  8  9 10]]]

Решение с использованием tf.where:

import tensorflow as tf

def select_n_fist_indices(n, batch_size):
    """ Return a list of length batch_size with the n first elements True
        and the rest False, i.e. [*[[True] * n], *[[False] * (batch_size - n)]]. 
    """
    n_ones = tf.ones((n))
    rest_zeros = tf.zeros((batch_size - n))
    indices = tf.cast(tf.concat((n_ones, rest_zeros), axis=0), dtype=tf.bool)

    return indices

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])

    # tf.where can't apply the condition to any axis (see doc).
    # In your case (condition on 2nd axis), we need either to manually broadcast the 
    # condition tensor, or transpose the target tensors.
    # Here is a quick demonstration with the 2nd solution:
    y_pred_tranposed = tf.transpose(y_pred, [2, 0, 1])
    y_true_tranposed = tf.transpose(y_true, [2, 0, 1])

    edit_indices = select_n_fist_indices(5, tf.shape(y_pred_tranposed)[0])

    y_pred_tranposed = tf.where(condition=edit_indices, 
                                x=y_pred_tranposed * y_true_tranposed, y=y_pred_tranposed)

    # Transpose back:  
    y_pred = tf.transpose(y_pred_tranposed, [1, 2, 0])
    print(y_pred.eval())
    # [[[ 1  4  9 16 25  6  7  8  9 10]]]
person benjaminplanche    schedule 05.04.2018
comment
Хорошо, но как мне сохранить другие значения? Я все еще хочу иметь массив формы (x, y, z), это даст мне массив (x, y, 5)? Как мне вернуть (x, y, 5) в y_pred, чтобы теперь у меня были умноженные столбцы и не умноженные ограничивающие рамки? - person puffadder; 05.04.2018
comment
Да, плохо, потом заметил, что ответил только на половину вашего вопроса. Я редактировал с более полным решением. - person benjaminplanche; 05.04.2018