Я прорабатываю руководство PyTorch по Определение нового автограда функции. Функция автограда, которую я хочу реализовать, представляет собой оболочку вокруг _1 _ А>. Вот что у меня есть на данный момент:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as tag
class SquareAndMaxPool1d(tag.Function):
@staticmethod
def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1, \
return_indices=False, ceil_mode=False):
ctx.save_for_backward( input )
inputC = input.clone() #copy input
inputC *= inputC
output = F.max_pool1d(inputC, kernel_size, stride=stride, \
padding=padding, dilation=dilation, \
return_indices=return_indices, \
ceil_mode=ceil_mode)
return output
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = get_max_pool1d_grad_somehow(grad_output)
return 2.0*input*grad_input
Мой вопрос: как получить градиент обернутой функции? Я знаю, что, вероятно, есть и другие способы сделать это, учитывая, насколько простой пример, который я представляю, но то, что я хочу сделать, соответствует этой структуре и требует от меня реализации функции autograd
.
Изменить: после изучения этого сообщения в блоге Решил попробовать для backward
следующее:
def backward(ctx, grad_output):
input, output = ctx.saved_tensors
grad_input = output.backward(grad_output)
return 2.0*input*grad_input
с добавлением output
к сохраненным переменным. Затем я запускаю следующий код:
x = np.random.randn(1,1,5)
xT = torch.from_numpy(x)
xT.requires_grad=True
f = SquareAndMaxPool1d.apply
s = torch.sum(f(xT,2))
s.backward()
и я получаю Bus error: 10
.
Скажем, xT
равно tensor([[[ 1.69533562, -0.21779421, 2.28693953, -0.86688095, -1.01033497]]], dtype=torch.float64)
, тогда я ожидал бы обнаружить, что xT.grad
равно tensor([[[ 3.39067124, -0. , 9.14775812, -0. , -2.02066994]]], dtype=torch.float64)
после вызова s.backward()
(то есть 2*x*grad_of_max_pool
, где grad_of_max_pool
содержит tensor([[[1., 0., 2., 0., 1.]]], dtype=torch.float64)
).
Я понял, почему у меня Bus error: 10
. Похоже, что приведенный выше код приводит к рекурсивному вызову моего backward
по адресу grad_input = output.backward(grad_output)
. Поэтому мне нужно найти другой способ получить градиент max_pool1d
. Я знаю, как реализовать это на чистом Python, но результат был бы намного медленнее, чем если бы я мог обернуть код библиотеки.