Проверить теорему свертки с помощью pytorch

В основном эта теорема формулируется следующим образом:

F(f*g) = F(f)xF(g)

Я знаю эту теорему, но я просто не могу воспроизвести результат с помощью pytorch.

Ниже приведен воспроизводимый код:

import torch
import torch.nn.functional as F

# calculate f*g
f = torch.ones((1,1,5,5))
g = torch.tensor(list(range(9))).view(1,1,3,3).float()
conv = F.conv2d(f, g, bias=None, padding=2)

# calculate F(f*g)
F_fg = torch.rfft(conv, signal_ndim=2, onesided=False)

# calculate F x G
f = f.squeeze()
g = g.squeeze()

# need to pad into at least [w1+w2-1, h1+h2-1], which is 7 in our case.
size = f.size(0) + g.size(0) - 1 

f_new = torch.zeros((7,7))
g_new = torch.zeros((7,7))

f_new[1:6,1:6] = f
g_new[2:5,2:5] = g

F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
FxG = torch.mul(F_f, F_g)

print(FxG - F_fg)

вот результат для print(FxG - F_fg)

tensor([[[[[ 0.0000e+00,  0.0000e+00],
       [ 4.1426e+02,  1.7270e+02],
       [-3.6546e+01,  4.7600e+01],
       [-1.0216e+01, -4.1198e+01],
       [-1.0216e+01, -2.0223e+00],
       [-3.6546e+01, -6.2804e+01],
       [ 4.1426e+02, -1.1427e+02]],

      ...

      [[ 4.1063e+02, -2.2347e+02],
       [-7.6294e-06,  2.2817e+01],
       [-1.9024e+01, -9.0105e+00],
       [ 7.1708e+00, -4.1027e+00],
       [-2.6739e+00, -1.1121e+01],
       [ 8.8471e+00,  7.1710e+00],
       [ 4.2528e+01,  9.7559e+01]]]]])

и вы можете видеть, что разница не всегда равна 0.

может кто-нибудь сказать мне, почему и как это сделать правильно?

Спасибо


person BarCodeReader    schedule 06.03.2020    source источник
comment
То, что они называют сверткой в ​​литературе CNN, на самом деле известно как корреляционная фильтрация в жаргоне обработки сигналов. В основном ядро ​​​​не переворачивается перед скольжением и умножением в CNN. Попробуйте F_g = torch.rfft(g_new.flip(0).flip(1), ..., что должно приблизить вас к результату. Также могут быть некоторые различия в заполнении, поскольку ДПФ предполагает, что сигналы являются периодическими (необходимыми для дискретности преобразования Фурье). Я проверю это позже.   -  person jodag    schedule 06.03.2020


Ответы (1)


Поэтому я внимательно посмотрел на то, что вы сделали до сих пор. Я определил три источника ошибок в вашем коде. Я постараюсь подробно рассмотреть каждый из них здесь.

1. Комплексная арифметика

PyTorch в настоящее время не поддерживает умножение комплексных чисел (AFAIK). Операция БПФ просто возвращает тензор с реальной и мнимой размерностью. Вместо использования оператора torch.mul или * нам нужно явно закодировать комплексное умножение.

(a + ib) * (c + id) = (a*c - b*d) + i(a*d + b*c)

2. Определение свертки

Определение свертки, часто используемое в литературе CNN, на самом деле отличается от определения, используемого при обсуждении теоремы свертки. Я не буду вдаваться в подробности, но теоретическое определение переворачивает ядро ​​перед скольжением и умножение. Вместо этого операция свертки в pytorch, tensorflow, caffe и т. д. не выполняет это переворачивание.

Чтобы учесть это, мы можем просто перевернуть g (как по горизонтали, так и по вертикали) перед применением БПФ.

3. Положение якоря

Точка привязки при использовании теоремы свертки считается левым верхним углом дополненного g. Опять же, я не буду вдаваться в подробности об этом, но это то, как работает математика.


Второй и третий пункт может быть легче понять на примере. Предположим, вы использовали следующие g

[1 2 3]
[4 5 6]
[7 8 9]

вместо g_new

[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 2 3 0 0]
[0 0 4 5 6 0 0]
[0 0 7 8 9 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]

это должно быть на самом деле

[5 4 0 0 0 0 6]
[2 1 0 0 0 0 3]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[8 7 0 0 0 0 9]

где мы отражаем ядро ​​по вертикали и горизонтали, затем применяем круговой сдвиг, чтобы центр ядра находился в верхнем левом углу.


В итоге я переписал большую часть вашего кода и немного обобщил его. Самая сложная операция — правильно определить g_new. Я решил использовать сетку и арифметику по модулю, чтобы одновременно переворачивать и сдвигать индексы. Если что-то здесь не имеет смысла для вас, оставьте комментарий, и я постараюсь уточнить.

import torch
import torch.nn.functional as F

def conv2d_pyt(f, g):
    assert len(f.size()) == 2
    assert len(g.size()) == 2

    f_new = f.unsqueeze(0).unsqueeze(0)
    g_new = g.unsqueeze(0).unsqueeze(0)

    pad_y = (g.size(0) - 1) // 2
    pad_x = (g.size(1) - 1) // 2

    fcg = F.conv2d(f_new, g_new, bias=None, padding=(pad_y, pad_x))
    return fcg[0, 0, :, :]

def conv2d_fft(f, g):
    assert len(f.size()) == 2
    assert len(g.size()) == 2

    # in general not necessary that inputs are odd shaped but makes life easier
    assert f.size(0) % 2 == 1
    assert f.size(1) % 2 == 1
    assert g.size(0) % 2 == 1
    assert g.size(1) % 2 == 1

    size_y = f.size(0) + g.size(0) - 1
    size_x = f.size(1) + g.size(1) - 1

    f_new = torch.zeros((size_y, size_x))
    g_new = torch.zeros((size_y, size_x))

    # copy f to center
    f_pad_y = (f_new.size(0) - f.size(0)) // 2
    f_pad_x = (f_new.size(1) - f.size(1)) // 2
    f_new[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x] = f

    # anchor of g is 0,0 (flip g and wrap circular)
    g_center_y = g.size(0) // 2
    g_center_x = g.size(1) // 2
    g_y, g_x = torch.meshgrid(torch.arange(g.size(0)), torch.arange(g.size(1)))
    g_new_y = (g_y.flip(0) - g_center_y) % g_new.size(0)
    g_new_x = (g_x.flip(1) - g_center_x) % g_new.size(1)
    g_new[g_new_y, g_new_x] = g[g_y, g_x]

    # take fft of both f and g
    F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
    F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)

    # complex multiply
    FxG_real = F_f[:, :, 0] * F_g[:, :, 0] - F_f[:, :, 1] * F_g[:, :, 1]
    FxG_imag = F_f[:, :, 0] * F_g[:, :, 1] + F_f[:, :, 1] * F_g[:, :, 0]
    FxG = torch.stack([FxG_real, FxG_imag], dim=2)

    # inverse fft
    fcg = torch.irfft(FxG, signal_ndim=2, onesided=False)

    # crop center before returning
    return fcg[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x]


# calculate f*g
f = torch.randn(11, 7)
g = torch.randn(5, 3)

fcg_pyt = conv2d_pyt(f, g)
fcg_fft = conv2d_fft(f, g)

avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()

print('Average difference:', avg_diff)

Что дает мне

Average difference: 4.6866085767760524e-07

Это очень близко к нулю. Причина, по которой мы не получаем ровно ноль, заключается просто в ошибках с плавающей запятой.

person jodag    schedule 08.03.2020
comment
Что является хорошим ресурсом для получения дополнительной информации об этом, особенно о циклическом сдвиге и заполнении ядра нулями? - person Kiran; 24.03.2021
comment
@Kiran 1) Сигнал является дискретным по частоте тогда и только тогда, когда он периодичен во времени (аналогично сигнал дискретен во времени тогда и только тогда, когда он периодичен по частоте). Поэтому ДПФ, который переходит от дискретного времени к дискретной частоте, предполагает, что сигнал является периодическим во времени. Это объясняет, почему все смены цикличны. 2) Позиция привязки обусловлена ​​соглашением об интерпретации входных данных для ДПФ как одного периода сигнала, начинающегося в момент времени t=0. Вы можете изучить это в любой книге по DSP, моя любимая — «Обработка сигналов дискретного времени» Оппенгейма. - person jodag; 13.04.2021