Как сделать автоматическую дифференциацию сложных типов данных?

Учитывая очень простое определение матрицы на основе вектора:

import Numeric.AD
import qualified Data.Vector as V

newtype Mat a = Mat { unMat :: V.Vector a }

scale' f = Mat . V.map (*f) . unMat
add' a b = Mat $ V.zipWith (+) (unMat a) (unMat b)
sub' a b = Mat $ V.zipWith (-) (unMat a) (unMat b)
mul' a b = Mat $ V.zipWith (*) (unMat a) (unMat b)
pow' a e = Mat $ V.map (^e) (unMat a)

sumElems' :: Num a => Mat a -> a
sumElems' = V.sum . unMat

(для демонстрационных целей... я использую hmatrix, но думал, что проблема как-то там)

И функция ошибки (eq3):

eq1' :: Num a => [a] -> [Mat a] -> Mat a
eq1' as φs = foldl1 add' $ zipWith scale' as φs

eq3' :: Num a => Mat a -> [a] -> [Mat a] -> a
eq3' img as φs = negate $ sumElems' (errImg `pow'` (2::Int))
  where errImg = img `sub'` (eq1' as φs)

Почему компилятор не может вывести правильные типы в этом?

diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m φs as0 = gradientDescent go as0
  where go xs = eq3' m xs φs

Точное сообщение об ошибке:

src/Stuff.hs:59:37:
    Could not deduce (a ~ Numeric.AD.Internal.Reverse.Reverse s a)
    from the context (Fractional a, Ord a)
      bound by the type signature for
                 diffTest :: (Fractional a, Ord a) =>
                             Mat a -> [Mat a] -> [a] -> [[a]]
      at src/Stuff.hs:58:13-69
    or from (reflection-1.5.1.2:Data.Reflection.Reifies
               s Numeric.AD.Internal.Reverse.Tape)
      bound by a type expected by the context:
                 reflection-1.5.1.2:Data.Reflection.Reifies
                   s Numeric.AD.Internal.Reverse.Tape =>
                 [Numeric.AD.Internal.Reverse.Reverse s a]
                 -> Numeric.AD.Internal.Reverse.Reverse s a
      at src/Stuff.hs:59:21-42
      ‘a’ is a rigid type variable bound by
          the type signature for
            diffTest :: (Fractional a, Ord a) =>
                        Mat a -> [Mat a] -> [a] -> [[a]]
          at src//Stuff.hs:58:13
    Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
                   -> Numeric.AD.Internal.Reverse.Reverse s a
      Actual type: [a] -> a
    Relevant bindings include
      go :: [a] -> a (bound at src/Stuff.hs:60:9)
      as0 :: [a] (bound at src/Stuff.hs:59:15)
      φs :: [Mat a] (bound at src/Stuff.hs:59:12)
      m :: Mat a (bound at src/Stuff.hs:59:10)
      diffTest :: Mat a -> [Mat a] -> [a] -> [[a]]
        (bound at src/Stuff.hs:59:1)
    In the first argument of ‘gradientDescent’, namely ‘go’
    In the expression: gradientDescent go as0

person fho    schedule 01.04.2015    source источник
comment
Я не уверен, действительно ли я уже задавал этот вопрос вчера. Я думал, что сделал ... но это не появляется в моих последних вопросах ...   -  person fho    schedule 01.04.2015


Ответы (1)


Функция gradientDescent из ad имеет тип

gradientDescent :: (Traversable f, Fractional a, Ord a) =>
                   (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) ->
                   f a -> [f a]

Его первый аргумент требует функции типа f r -> r, где r равно forall s. (Reverse s a). go имеет тип [a] -> a, где a — это тип, связанный с сигнатурой diffTest. Эти a одинаковы, но Reverse s a не то же самое, что a.

Тип Reverse имеет экземпляры для ряд классов типов, которые могли бы позволить нам преобразовать a в Reverse s a или обратно. Наиболее очевидным является Fractional a => Fractional (Reverse s a), который позволит нам преобразовать as в Reverse s as с помощью realToFrac.

Для этого нам потребуется сопоставить функцию a -> b с Mat a, чтобы получить Mat b. Самый простой способ сделать это — получить экземпляр Functor для Mat.

{-# LANGUAGE DeriveFunctor #-}

newtype Mat a = Mat { unMat :: V.Vector a }
    deriving Functor

Мы можем преобразовать m и fs в любое Fractional a' => Mat a' с fmap realToFrac.

diffTest m fs as0 = gradientDescent go as0
  where go xs = eq3' (fmap realToFrac m) xs (fmap (fmap realToFrac) fs)

Но есть лучший способ спрятаться в рекламном пакете. Reverse s a универсален для всех s, но a является тем же a, что и привязанный в сигнатуре типа для diffTest. Нам действительно нужна только функция a -> (forall s. Reverse s a). Это функция auto из класса Mode. , для которого Reverse s a имеет экземпляр. auto имеет немного странный тип Mode t => Scalar t -> t, но type Scalar (Reverse s a) = a. Специализированный для Reverse auto имеет тип

auto :: (Reifies s Tape, Num a) => a -> Reverse s a

Это позволяет нам конвертировать наши Mat as в Mat (Reverse s a)s, не возясь с преобразованиями в Rational и обратно.

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m fs as0 = gradientDescent go as0
  where
    go :: forall t. (Scalar t ~ a, Mode t) => [t] -> t
    go xs = eq3' (fmap auto m) xs (fmap (fmap auto) fs)
person Cirdec    schedule 01.04.2015
comment
Отличная запись. Я просто хотел пожаловаться, что это должно быть очевидно из документации ad ... просто чтобы найти это прямо там на github. Теперь я чувствую себя глупо :) - person fho; 01.04.2015
comment
что делает ~ в подписи go ? - person ocramz; 06.08.2015
comment
~ - равенство типов. В нем говорится, что Scalar из t должны быть того же типа, что и a в исходных Mat a. - person Cirdec; 06.08.2015