работа с доказательствами с использованием CmpNat и синглтонов в Haskell

Я пытаюсь создать некоторые функции для работы со следующим типом. В следующем коде используются одиночки и ограничения на GHC-8.4.1:

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Constraint ((:-))
import Data.Singletons (sing)
import Data.Singletons.Prelude (Sing(SEQ, SGT, SLT), (%+), sCompare)
import Data.Singletons.Prelude.Num (PNum((+)))
import Data.Singletons.TypeLits (SNat)
import GHC.TypeLits (CmpNat, Nat)

data Foo where
  Foo
    :: forall (index :: Nat) (len :: Nat).
       (CmpNat index len ~ 'LT)
    => SNat index
    -> SNat len
    -> Foo

Это GADT, который содержит длину и индекс. Индекс гарантированно меньше длины.

Достаточно просто написать функцию для создания Foo:

createFoo :: Foo
createFoo = Foo (sing :: SNat 0) (sing :: SNat 1)

Однако у меня возникли проблемы с написанием функции, которая увеличивает len, сохраняя при этом index одинаковым:

incrementLength :: Foo -> Foo
incrementLength (Foo index len) = Foo index (len %+ (sing :: SNat 1))

Это не удается со следующей ошибкой:

file.hs:34:34: error:
    • Could not deduce: CmpNat index (len GHC.TypeNats.+ 1) ~ 'LT
        arising from a use of ‘Foo’
      from the context: CmpNat index len ~ 'LT
        bound by a pattern with constructor:
                   Foo :: forall (index :: Nat) (len :: Nat).
                          (CmpNat index len ~ 'LT) =>
                          SNat index -> SNat len -> Foo,
                 in an equation for ‘incrementLength’
        at what5.hs:34:17-29
    • In the expression: Foo index (len %+ (sing :: SNat 1))
      In an equation for ‘incrementLength’:
          incrementLength (Foo index len)
            = Foo index (len %+ (sing :: SNat 1))
    • Relevant bindings include
        len :: SNat len (bound at what5.hs:34:27)
        index :: SNat index (bound at what5.hs:34:21)
   |
34 | incrementLength (Foo index len) = Foo index (len %+ (sing :: SNat 1))
   |                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Это имеет смысл, поскольку компилятор знает, что CmpNat index len ~ 'LT (из определения Foo), но не знает, что CmpNat index (len + 1) ~ 'LT.

Есть ли способ заставить что-то подобное работать?

Можно использовать sCompare, чтобы сделать что-то вроде этого:

incrementLength :: Foo -> Foo
incrementLength (Foo index len) =
  case sCompare index (len %+ (sing :: SNat 1)) of
    SLT -> Foo index (len %+ (sing :: SNat 1))
    SEQ -> error "not eq"
    SGT -> error "not gt"

Однако, к сожалению, мне приходится писать случаи для SEQ и SGT, когда я знаю, что они никогда не будут совпадать.

Кроме того, я подумал, что возможно создать тип, подобный следующему:

plusOneLTProof :: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
plusOneLTProof = undefined

Однако это дает ошибку, подобную следующей:

file.hs:40:8: error:
    • Couldn't match type ‘CmpNat n0 m0’ with ‘CmpNat n m’
      Expected type: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
        Actual type: (CmpNat n0 m0 ~ 'LT) :- (CmpNat n0 (m0 + 1) ~ 'LT)
      NB: ‘CmpNat’ is a non-injective type family
      The type variables ‘n0’, ‘m0’ are ambiguous
    • In the ambiguity check for ‘bar’
      To defer the ambiguity check to use sites, enable AllowAmbiguousTypes
      In the type signature:
        bar :: (CmpNat n m ~  'LT) :- (CmpNat n (m + 1) ~  'LT)
   |
40 | bar :: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
   |        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Думаю, это имеет смысл, поскольку CmpNat не является инъективным. Однако я знаю, что это утверждение верно, поэтому я хотел бы иметь возможность написать эту функцию.


Я хотел бы получить ответ на следующие два вопроса:

  1. Есть ли способ написать incrementLength, где вам нужно будет сопоставить только SLT? Я не против изменить определение Foo, чтобы упростить задачу.

  2. Есть ли способ написать plusOneLTProof или хотя бы что-то подобное?


Обновление. В итоге я воспользовался предложением Ли-яо Ся написать plusOneLTProof и incrementLength следующим образом:

incrementLength :: Foo -> Foo
incrementLength (Foo index len) =
  case plusOneLTProof index len of
    Sub Dict -> Foo index (len %+ (sing :: SNat 1))

plusOneLTProof :: forall n m. SNat n -> SNat m -> (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
plusOneLTProof SNat SNat = Sub axiom
  where
    axiom :: CmpNat n m ~ 'LT => Dict (CmpNat n (m + 1) ~ 'LT)
    axiom = unsafeCoerce (Dict :: Dict (a ~ a))

Это требует, чтобы вы передали два SNat в plusOneLTProof, но не требует AllowAmbiguousTypes.


person illabout    schedule 05.04.2018    source источник


Ответы (1)


Компилятор отклоняет plusOneLTProof, поскольку его тип неоднозначен. Мы можем отключить это ограничение с помощью расширения AllowAmbiguousTypes. Я бы порекомендовал использовать это вместе с ExplicitForall (что подразумевается ScopedTypeVariables, которое нам в любом случае понадобится, или RankNTypes). Это для определения. Определение неоднозначного типа можно использовать с TypeApplications.

Однако GHC по-прежнему не может рассуждать о натуральных числах, поэтому мы не можем определить plusOneLTProof = Sub Dict, тем более incrementLength, небезопасно.

Но мы все еще можем создать доказательство из воздуха с помощью unsafeCoerce. Именно так модуль Data.Constraint.Nat в ограничениях реализовано; к сожалению, в настоящее время он не содержит никаких фактов о CmpNat. Приведение работает, потому что в равенствах типов нет содержимого времени выполнения. Даже если значения времени выполнения выглядят нормально, утверждение бессвязных фактов может привести к тому, что программы будут работать неправильно.

plusOneLTProof :: forall n m. (CmpNat n m ~ 'LT) :- (CmpNat n (m+1) ~ 'LT)
plusOneLTProof = Sub axiom
  where
    axiom :: (CmpNat n m ~ 'LT) => Dict (CmpNat n (m+1) ~ 'LT)
    axiom = unsafeCoerce (Dict :: Dict (a ~ a))

Чтобы использовать доказательство, мы специализируем его (с помощью TypeApplications) и сопоставляем его с образцом, чтобы ввести RHS в контекст, проверяя, что LHS выполняется.

incrementLength :: Foo -> Foo
incrementLength (Foo (n :: SNat n) (m :: SNat m)) =
  case plusOneLTProof @n @m of
    Sub Dict -> Foo n (m %+ (sing :: SNat 1))
person Li-yao Xia    schedule 05.04.2018