Pytorch: эффективная с памятью взвешенная сумма с весами, распределенными по каналам

Входы:

1) I = Тензор dim (N, C, X) (Вход)

2) W = Тензор dim (N, X, Y) (Вес)

Выход:

1) O = Тензор dim (N, C, Y) (Выход)

Я хочу вычислить:

I = I.view(N, C, X, 1)
W = W.view(N, 1, X, Y)
PROD = I*W
O = PROD.sum(dim=2)
return O

без увеличения объема памяти N * C * X * Y.

В основном я хочу рассчитать взвешенную сумму карты функций, в которой веса одинаковы по размеру канала, без дополнительных затрат памяти на канал.


Может быть, я мог бы использовать

from itertools import product

O = torch.zeros(N, C, Y)
for n, x, y in product(range(N), range(X), range(Y)):
    O[n, :, y] += I[n, :, x]*W[n, x, y]
return O

но это будет медленнее (без широковещательной передачи), и я не уверен, сколько накладных расходов памяти возникнет при сохранении переменных для обратного прохода.


person Soham    schedule 03.05.2020    source источник


Ответы (1)


Вы можете использовать torch.bmm (https://pytorch.org/docs/stable/torch.html#torch.bmm). Просто сделай torch.bmm(I,W)

Чтобы проверить результаты:

import torch
N, C, X, Y= 100, 10, 9, 8 

i = torch.rand(N,C,X)
w = torch.rand(N,X,Y)
o = torch.bmm(i,w)

# desired result code
I = i.view(N, C, X, 1)
W = w.view(N, 1, X, Y)
PROD = I*W
O = PROD.sum(dim=2)

print(torch.allclose(O,o)) # should output True if outputs are same.

РЕДАКТИРОВАТЬ: В идеале я бы предположил, что использование внутреннего умножения матриц pytorch эффективно. Однако вы также можете измерить использование памяти с помощью tracemalloc (по крайней мере, на ЦП). См. https://discuss.pytorch.org/t/measuring-peak-memory-usage-tracemalloc-for-pytorch/34067 для графического процессора.

import torch
import tracemalloc
tracemalloc.start()
N, C, X, Y= 100, 10, 9, 8 

i = torch.rand(N,C,X)
w = torch.rand(N,X,Y)
o = torch.bmm(i,w)
# output is a tuple indicating current memory and peak memory
print(tracemalloc.get_traced_memory())  

Вы можете сделать то же самое с другим кодом и убедиться, что реализация bmm действительно эффективна.

import torch
import tracemalloc
tracemalloc.start()
N, C, X, Y= 100, 10, 9, 8 

i = torch.rand(N,C,X)
w = torch.rand(N,X,Y)

I = i.view(N, C, X, 1)
W = w.view(N, 1, X, Y)
PROD = I*W
O = PROD.sum(dim=2)
# output is a tuple indicating current memory and peak memory
print(tracemalloc.get_traced_memory())  

person Umang Gupta    schedule 03.05.2020
comment
Большое спасибо! А как проверить потребление памяти? Думаю, справедливо предположить, что это оптимальный способ сделать это. - person Soham; 04.05.2020