Я пытаюсь создать некоторые функции для работы со следующим типом. В следующем коде используются одиночки и ограничения на GHC-8.4.1:
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}
import Data.Constraint ((:-))
import Data.Singletons (sing)
import Data.Singletons.Prelude (Sing(SEQ, SGT, SLT), (%+), sCompare)
import Data.Singletons.Prelude.Num (PNum((+)))
import Data.Singletons.TypeLits (SNat)
import GHC.TypeLits (CmpNat, Nat)
data Foo where
Foo
:: forall (index :: Nat) (len :: Nat).
(CmpNat index len ~ 'LT)
=> SNat index
-> SNat len
-> Foo
Это GADT, который содержит длину и индекс. Индекс гарантированно меньше длины.
Достаточно просто написать функцию для создания Foo
:
createFoo :: Foo
createFoo = Foo (sing :: SNat 0) (sing :: SNat 1)
Однако у меня возникли проблемы с написанием функции, которая увеличивает len
, сохраняя при этом index
одинаковым:
incrementLength :: Foo -> Foo
incrementLength (Foo index len) = Foo index (len %+ (sing :: SNat 1))
Это не удается со следующей ошибкой:
file.hs:34:34: error:
• Could not deduce: CmpNat index (len GHC.TypeNats.+ 1) ~ 'LT
arising from a use of ‘Foo’
from the context: CmpNat index len ~ 'LT
bound by a pattern with constructor:
Foo :: forall (index :: Nat) (len :: Nat).
(CmpNat index len ~ 'LT) =>
SNat index -> SNat len -> Foo,
in an equation for ‘incrementLength’
at what5.hs:34:17-29
• In the expression: Foo index (len %+ (sing :: SNat 1))
In an equation for ‘incrementLength’:
incrementLength (Foo index len)
= Foo index (len %+ (sing :: SNat 1))
• Relevant bindings include
len :: SNat len (bound at what5.hs:34:27)
index :: SNat index (bound at what5.hs:34:21)
|
34 | incrementLength (Foo index len) = Foo index (len %+ (sing :: SNat 1))
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Это имеет смысл, поскольку компилятор знает, что CmpNat index len ~ 'LT
(из определения Foo), но не знает, что CmpNat index (len + 1) ~ 'LT
.
Есть ли способ заставить что-то подобное работать?
Можно использовать sCompare
, чтобы сделать что-то вроде этого:
incrementLength :: Foo -> Foo
incrementLength (Foo index len) =
case sCompare index (len %+ (sing :: SNat 1)) of
SLT -> Foo index (len %+ (sing :: SNat 1))
SEQ -> error "not eq"
SGT -> error "not gt"
Однако, к сожалению, мне приходится писать случаи для SEQ
и SGT
, когда я знаю, что они никогда не будут совпадать.
Кроме того, я подумал, что возможно создать тип, подобный следующему:
plusOneLTProof :: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
plusOneLTProof = undefined
Однако это дает ошибку, подобную следующей:
file.hs:40:8: error:
• Couldn't match type ‘CmpNat n0 m0’ with ‘CmpNat n m’
Expected type: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
Actual type: (CmpNat n0 m0 ~ 'LT) :- (CmpNat n0 (m0 + 1) ~ 'LT)
NB: ‘CmpNat’ is a non-injective type family
The type variables ‘n0’, ‘m0’ are ambiguous
• In the ambiguity check for ‘bar’
To defer the ambiguity check to use sites, enable AllowAmbiguousTypes
In the type signature:
bar :: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
|
40 | bar :: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Думаю, это имеет смысл, поскольку CmpNat не является инъективным. Однако я знаю, что это утверждение верно, поэтому я хотел бы иметь возможность написать эту функцию.
Я хотел бы получить ответ на следующие два вопроса:
Есть ли способ написать
incrementLength
, где вам нужно будет сопоставить толькоSLT
? Я не против изменить определениеFoo
, чтобы упростить задачу.Есть ли способ написать
plusOneLTProof
или хотя бы что-то подобное?
Обновление. В итоге я воспользовался предложением Ли-яо Ся написать plusOneLTProof
и incrementLength
следующим образом:
incrementLength :: Foo -> Foo
incrementLength (Foo index len) =
case plusOneLTProof index len of
Sub Dict -> Foo index (len %+ (sing :: SNat 1))
plusOneLTProof :: forall n m. SNat n -> SNat m -> (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
plusOneLTProof SNat SNat = Sub axiom
where
axiom :: CmpNat n m ~ 'LT => Dict (CmpNat n (m + 1) ~ 'LT)
axiom = unsafeCoerce (Dict :: Dict (a ~ a))
Это требует, чтобы вы передали два SNat
в plusOneLTProof
, но не требует AllowAmbiguousTypes
.