diff options
Diffstat (limited to 'src/Data')
-rw-r--r-- | src/Data/Array/Nested.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 61 | ||||
-rw-r--r-- | src/Data/INat.hs | 121 |
3 files changed, 40 insertions, 146 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index c7d1819..c12d8ad 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -37,9 +37,6 @@ module Data.Array.Nested ( PrimElt, Primitive(..), - -- * Inductive natural numbers - module Data.INat, - -- * Further utilities / re-exports type (++), Storable, @@ -49,5 +46,4 @@ import Prelude hiding (mappend) import Data.Array.Mixed import Data.Array.Nested.Internal -import Data.INat import Foreign.Storable diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 54b567a..222247b 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -155,15 +155,26 @@ snatPred snp1 = EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) GTI -> Nothing + +-- Stupid things that the type checker should be able to figure out in-line, but can't + subst1 :: forall f a b. a :~: b -> f a :~: f b subst1 Refl = Refl subst2 :: forall f c a b. a :~: b -> f a c :~: f b c subst2 Refl = Refl +-- TODO: is this sound? @n@ cannot be negative, surely, but the plugin doesn't see even that. lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a lemReplicateSucc = unsafeCoerce Refl +lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l +lemAppLeft _ Refl = Refl + +knownNatSucc :: KnownNat n => Dict KnownNat (1 + n) +knownNatSucc = Dict + + lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) where @@ -947,7 +958,7 @@ type role ListR nominal representational type ListR :: Nat -> Type -> Type data ListR n i where ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (1 + n) i + (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i deriving instance Show i => Show (ListR n i) deriving instance Eq i => Eq (ListR n i) deriving instance Ord i => Ord (ListR n i) @@ -963,7 +974,7 @@ listRToList (i ::: is) = i : listRToList is knownListR :: ListR n i -> Dict KnownNat n knownListR ZR = Dict -knownListR (_ ::: l) | Dict <- knownListR l = Dict +knownListR (_ ::: (l :: ListR m i)) | Dict <- knownListR l = knownNatSucc @m -- | An index into a rank-typed array. type role IxR nominal representational @@ -1040,11 +1051,11 @@ shCvtXR (n :$? idx) = n :$: shCvtXR idx ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX ZIR = ZIX -ixCvtRX (n :.: idx) = n :.? ixCvtRX idx +ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx) shCvtRX :: IShR n -> IShX (Replicate n Nothing) shCvtRX ZSR = ZSX -shCvtRX (n :$: idx) = n :$? shCvtRX idx +shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx) shapeSizeR :: IShR n -> Int shapeSizeR ZSR = 1 @@ -1084,19 +1095,19 @@ rlift f (Ranked arr) = Ranked (mlift f arr) rsumOuter1P :: forall n a. - (Storable a, Num a, KnownNat n, 1 <= n) - => Ranked n (Primitive a) -> Ranked (n - 1) (Primitive a) + (Storable a, Num a, KnownNat n) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) + , Refl <- lemReplicateSucc @(Nothing @Nat) @n = Ranked - . coerce @(XArray (Replicate (n - 1) 'Nothing) a) @(Mixed (Replicate (n - 1) 'Nothing) (Primitive a)) - . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate (n - 1) Nothing)) - . coerce @(Mixed (Replicate n Nothing) (Primitive a)) @(XArray (Replicate n Nothing) a) + . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) + . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) + . coerce @(Mixed (Replicate (n + 1) Nothing) (Primitive a)) @(XArray (Replicate (n + 1) Nothing) a) $ arr -rsumOuter1 :: forall n a. - (Storable a, Num a, PrimElt a, KnownNat n, 1 <= n) - => Ranked n a -> Ranked (n - 1) a +rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n) + => Ranked (1 + n) a -> Ranked n a rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a @@ -1104,9 +1115,12 @@ rtranspose perm (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mtranspose perm arr) -rappend :: forall n a. (KnownNat n, Elt a, 1 <= n) - => Ranked n a -> Ranked n a -> Ranked n a -rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend +rappend :: forall n a. (KnownNat n, Elt a) + => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +rappend + | Dict <- lemKnownReplicate (Proxy @n) + , Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) rscalar :: Elt a => a -> Ranked 0 a rscalar x = Ranked (mscalar x) @@ -1125,16 +1139,19 @@ rtoVectorP = coerce mtoVectorP rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector -rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (1 + n) a +rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromList1 l | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mfromList1 (coerce l)) + , Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l)) rfromList :: Elt a => NonEmpty a -> Ranked 1 a rfromList = Ranked . mfromList1 . fmap mscalar -rtoList :: Elt a => Ranked (1 + n) a -> [Ranked n a] -rtoList (Ranked arr) = coerce (mtoList1 arr) +rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoList (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr) rtoList1 :: Elt a => Ranked 1 a -> [a] rtoList1 = map runScalar . rtoList @@ -1154,8 +1171,10 @@ rconstant sh x = coerce fromPrimitive (rconstantP sh x) rslice :: (KnownNat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a rslice ivs = rlift $ \_ -> X.slice ivs -rrev1 :: (KnownNat n, Elt a, 1 <= n) => Ranked n a -> Ranked n a -rrev1 = rlift $ \_ -> X.rev1 +rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a +rrev1 = rlift $ \(Proxy @sh') -> + case lemReplicateSucc @(Nothing @Nat) @n of + Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh') rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a) => IShR n' -> Ranked n a -> Ranked n' a diff --git a/src/Data/INat.hs b/src/Data/INat.hs deleted file mode 100644 index af8f18b..0000000 --- a/src/Data/INat.hs +++ /dev/null @@ -1,121 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.INat where - -import Data.Proxy -import Data.Type.Equality ((:~:) (Refl)) -import Numeric.Natural -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) - --- | Evidence for the constraint @c a@. -data Dict c a where - Dict :: c a => Dict c a - --- | An inductive peano natural number. Intended to be used at the type level. -data INat = Z | S INat - deriving (Show) - --- | Singleton for a 'INat'. -data SINat n where - SZ :: SINat Z - SS :: SINat n -> SINat (S n) -deriving instance Show (SINat n) - --- | A singleton 'SINat' corresponding to @n@. -class KnownINat n where inatSing :: SINat n -instance KnownINat Z where inatSing = SZ -instance KnownINat n => KnownINat (S n) where inatSing = SS inatSing - --- | Explicitly bidirectional pattern synonym that converts between a singleton --- 'SINat' and evidence of a 'KnownINat' constraint. Analogous to 'GHC.SNat'. -pattern SINat' :: () => KnownINat n => SINat n -pattern SINat' <- (snatKnown -> Dict) - where SINat' = inatSing - --- | A 'KnownINat' dictionary is just a singleton natural, so we can create --- evidence of 'KnownINat' given an 'SINat'. -snatKnown :: SINat n -> Dict KnownINat n -snatKnown SZ = Dict -snatKnown (SS n) | Dict <- snatKnown n = Dict - --- | Convert a 'INat' to a normal number. -fromINat :: INat -> Natural -fromINat Z = 0 -fromINat (S n) = 1 + fromINat n - --- | Convert an 'SINat' to a normal number. -fromSINat :: SINat n -> Natural -fromSINat SZ = 0 -fromSINat (SS n) = 1 + fromSINat n - --- | The value of a known inductive natural as a value-level integer. -inatVal :: forall n. KnownINat n => Proxy n -> Natural -inatVal _ = fromSINat (inatSing @n) - --- | Add two 'INat's -type family n +! m where - Z +! m = m - S n +! m = S (n +! m) - --- | Convert a 'INat' to a "GHC.TypeLits" 'G.Nat'. -type family FromINat n where - FromINat Z = 0 - FromINat (S n) = 1 + FromINat n - --- | Convert a "GHC.TypeLits" 'G.Nat' to a 'INat'. -type family ToINat (n :: Nat) where - ToINat 0 = Z - ToINat n = S (ToINat (n - 1)) - -lemInjectiveFromINat :: n :~: ToINat (FromINat n) -lemInjectiveFromINat = unsafeCoerce Refl - -lemSuccFromINat :: Proxy n -> 1 + FromINat n :~: FromINat (S n) -lemSuccFromINat _ = unsafeCoerce Refl - -lemAddFromINat :: Proxy m -> Proxy n - -> FromINat m + FromINat n :~: FromINat (m +! n) -lemAddFromINat _ = unsafeCoerce Refl - -lemInjectiveToINat :: n :~: FromINat (ToINat n) -lemInjectiveToINat = unsafeCoerce Refl - -lemSuccToINat :: Proxy n -> ToINat (1 + n) :~: S (ToINat n) -lemSuccToINat _ = unsafeCoerce Refl - -lemAddToINat :: Proxy m -> Proxy n -> ToINat (m + n) :~: ToINat m +! ToINat n -lemAddToINat _ _ = unsafeCoerce Refl - --- | If an inductive 'INat' is known, then the corresponding "GHC.TypeLits" --- 'G.Nat' is also known. -knownNatFromINat :: KnownINat n => Proxy n -> Dict KnownNat (FromINat n) -knownNatFromINat (Proxy @n) = go (SINat' @n) - where - go :: SINat m -> Dict KnownNat (FromINat m) - go SZ = Dict - go (SS n) | Dict <- go n = Dict - --- * Some type-level inductive naturals - -type I0 = Z -type I1 = S I0 -type I2 = S I1 -type I3 = S I2 -type I4 = S I3 -type I5 = S I4 -type I6 = S I5 -type I7 = S I6 -type I8 = S I7 -type I9 = S I8 |