Входы:
1) I = Тензор dim (N, C, X) (Вход)
2) W = Тензор dim (N, X, Y) (Вес)
Выход:
1) O = Тензор dim (N, C, Y) (Выход)
Я хочу вычислить:
I = I.view(N, C, X, 1)
W = W.view(N, 1, X, Y)
PROD = I*W
O = PROD.sum(dim=2)
return O
без увеличения объема памяти N * C * X * Y.
В основном я хочу рассчитать взвешенную сумму карты функций, в которой веса одинаковы по размеру канала, без дополнительных затрат памяти на канал.
Может быть, я мог бы использовать
from itertools import product
O = torch.zeros(N, C, Y)
for n, x, y in product(range(N), range(X), range(Y)):
O[n, :, y] += I[n, :, x]*W[n, x, y]
return O
но это будет медленнее (без широковещательной передачи), и я не уверен, сколько накладных расходов памяти возникнет при сохранении переменных для обратного прохода.