Поэтому я внимательно посмотрел на то, что вы сделали до сих пор. Я определил три источника ошибок в вашем коде. Я постараюсь подробно рассмотреть каждый из них здесь.
1. Комплексная арифметика
PyTorch в настоящее время не поддерживает умножение комплексных чисел (AFAIK). Операция БПФ просто возвращает тензор с реальной и мнимой размерностью. Вместо использования оператора torch.mul
или *
нам нужно явно закодировать комплексное умножение.
(a + ib) * (c + id) = (a*c - b*d) + i(a*d + b*c)
2. Определение свертки
Определение свертки, часто используемое в литературе CNN, на самом деле отличается от определения, используемого при обсуждении теоремы свертки. Я не буду вдаваться в подробности, но теоретическое определение переворачивает ядро перед скольжением и умножение. Вместо этого операция свертки в pytorch, tensorflow, caffe и т. д. не выполняет это переворачивание.
Чтобы учесть это, мы можем просто перевернуть g
(как по горизонтали, так и по вертикали) перед применением БПФ.
3. Положение якоря
Точка привязки при использовании теоремы свертки считается левым верхним углом дополненного g
. Опять же, я не буду вдаваться в подробности об этом, но это то, как работает математика.
Второй и третий пункт может быть легче понять на примере. Предположим, вы использовали следующие g
[1 2 3]
[4 5 6]
[7 8 9]
вместо g_new
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 2 3 0 0]
[0 0 4 5 6 0 0]
[0 0 7 8 9 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
это должно быть на самом деле
[5 4 0 0 0 0 6]
[2 1 0 0 0 0 3]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[8 7 0 0 0 0 9]
где мы отражаем ядро по вертикали и горизонтали, затем применяем круговой сдвиг, чтобы центр ядра находился в верхнем левом углу.
В итоге я переписал большую часть вашего кода и немного обобщил его. Самая сложная операция — правильно определить g_new
. Я решил использовать сетку и арифметику по модулю, чтобы одновременно переворачивать и сдвигать индексы. Если что-то здесь не имеет смысла для вас, оставьте комментарий, и я постараюсь уточнить.
import torch
import torch.nn.functional as F
def conv2d_pyt(f, g):
assert len(f.size()) == 2
assert len(g.size()) == 2
f_new = f.unsqueeze(0).unsqueeze(0)
g_new = g.unsqueeze(0).unsqueeze(0)
pad_y = (g.size(0) - 1) // 2
pad_x = (g.size(1) - 1) // 2
fcg = F.conv2d(f_new, g_new, bias=None, padding=(pad_y, pad_x))
return fcg[0, 0, :, :]
def conv2d_fft(f, g):
assert len(f.size()) == 2
assert len(g.size()) == 2
# in general not necessary that inputs are odd shaped but makes life easier
assert f.size(0) % 2 == 1
assert f.size(1) % 2 == 1
assert g.size(0) % 2 == 1
assert g.size(1) % 2 == 1
size_y = f.size(0) + g.size(0) - 1
size_x = f.size(1) + g.size(1) - 1
f_new = torch.zeros((size_y, size_x))
g_new = torch.zeros((size_y, size_x))
# copy f to center
f_pad_y = (f_new.size(0) - f.size(0)) // 2
f_pad_x = (f_new.size(1) - f.size(1)) // 2
f_new[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x] = f
# anchor of g is 0,0 (flip g and wrap circular)
g_center_y = g.size(0) // 2
g_center_x = g.size(1) // 2
g_y, g_x = torch.meshgrid(torch.arange(g.size(0)), torch.arange(g.size(1)))
g_new_y = (g_y.flip(0) - g_center_y) % g_new.size(0)
g_new_x = (g_x.flip(1) - g_center_x) % g_new.size(1)
g_new[g_new_y, g_new_x] = g[g_y, g_x]
# take fft of both f and g
F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
# complex multiply
FxG_real = F_f[:, :, 0] * F_g[:, :, 0] - F_f[:, :, 1] * F_g[:, :, 1]
FxG_imag = F_f[:, :, 0] * F_g[:, :, 1] + F_f[:, :, 1] * F_g[:, :, 0]
FxG = torch.stack([FxG_real, FxG_imag], dim=2)
# inverse fft
fcg = torch.irfft(FxG, signal_ndim=2, onesided=False)
# crop center before returning
return fcg[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x]
# calculate f*g
f = torch.randn(11, 7)
g = torch.randn(5, 3)
fcg_pyt = conv2d_pyt(f, g)
fcg_fft = conv2d_fft(f, g)
avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()
print('Average difference:', avg_diff)
Что дает мне
Average difference: 4.6866085767760524e-07
Это очень близко к нулю. Причина, по которой мы не получаем ровно ноль, заключается просто в ошибках с плавающей запятой.
person
jodag
schedule
08.03.2020
F_g = torch.rfft(g_new.flip(0).flip(1), ...
, что должно приблизить вас к результату. Также могут быть некоторые различия в заполнении, поскольку ДПФ предполагает, что сигналы являются периодическими (необходимыми для дискретности преобразования Фурье). Я проверю это позже. - person jodag   schedule 06.03.2020