Добавить пользовательскую функцию потери в Torch

Каковы необходимые шаги для реализации пользовательской функции потерь в Torch?

Похоже, вам нужно написать реализацию для updateOutput и updateGradInput.

В том, что все? Итак, вы в основном создаете новый класс:

local CustomCriterion, parent =   torch.class('CustomCriterion','nn.Criterion')

и реализовать следующие две функции:

function CustomCriterion:updateOutput(input, target)
function CustomCriterion:updateGradInput(input, target)

Это правильно, или нужно еще что-то сделать?

Кроме того, для предоставленных критериев эти функции реализованы на C, но я полагаю, реализация Lua также будет работать, хотя, возможно, немного медленнее?


person DrMad    schedule 28.05.2016    source источник


Ответы (1)


Я реализовал функции формы (в псевдокоде)

--assuming input is partitioned in input_a,input_b
--         target is accordingly partitionend in target_a, target_b  
f(input)=MSE(input_a,target_a)+ custom_sutff(input_b,target_b)

несколько раз именно так, как вы описываете. Итак, насколько мне известно, я думаю, что ответ на оба ваших вопроса положительный.

В основном nn/MSECriterion.lua и это, похоже, подтверждает это.

person Ash    schedule 26.06.2017