TensorFlow 2.0: как обновить тензоры?

В TensorFlow 1.x для обновления тензора я бы использовал tf.scatter_update, чтобы обновлять только соответствующую часть тензора.

Как мы можем сделать то же самое в TF 2.0?


person gpernelle    schedule 12.04.2019    source источник


Ответы (1)


Вы можете использовать tf.tensor_scatter_nd_update():

import tensorflow as tf
import numpy as np 

tensor = tf.convert_to_tensor(np.ones((2, 2)), dtype=tf.float32)
indices = tf.constant([[0, 0]])
updates = tf.constant([0.0])

tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
# array([[0., 1.],
#        [1., 1.]], dtype=float32)
person Vlad    schedule 12.04.2019
comment
На самом деле это не то, что мне нужно, это создает новый объект. Есть ли способ использовать ссылку на тот же объект? Если я хочу обновить некоторые переменные памяти, такие как буфер воспроизведения, tensorflow возвращает ошибку: операция за пределами кода построения функции передает тензор Graph. - person gpernelle; 15.07.2019
comment
Я нашел ответ на свой вопрос: tensor.assign (tf.tensor_scatter_nd_update (.... работает. - person gpernelle; 15.07.2019
comment
@gpernelle В вашем вопросе вы действительно хотели обновить переменную, а не тензорный объект (это то, что tf.scatter_update сделал в 1.x). Таким образом, ответ Влада верен при обновлении тензоров. Ваш ответ верен для обновления переменных, хотя я не знаю, почему эта функция исчезла в 2.x. - person RVH; 20.04.2020