Каковы необходимые шаги для реализации пользовательской функции потерь в Torch?
Похоже, вам нужно написать реализацию для updateOutput и updateGradInput.
В том, что все? Итак, вы в основном создаете новый класс:
local CustomCriterion, parent = torch.class('CustomCriterion','nn.Criterion')
и реализовать следующие две функции:
function CustomCriterion:updateOutput(input, target)
function CustomCriterion:updateGradInput(input, target)
Это правильно, или нужно еще что-то сделать?
Кроме того, для предоставленных критериев эти функции реализованы на C, но я полагаю, реализация Lua также будет работать, хотя, возможно, немного медленнее?