diff options
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 11 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Lemmas.hs | 86 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 17 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 17 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Types.hs | 48 |
5 files changed, 138 insertions, 41 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index cf6820b..bb3ee4a 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} @@ -9,14 +10,14 @@ module Data.Array.Mixed.Internal.Arith where import Control.Monad (forM, guard) -import qualified Data.Array.Internal as OI -import qualified Data.Array.Internal.RankedG as RG -import qualified Data.Array.Internal.RankedS as RS +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS import Data.Bits import Data.Int import Data.List (sort) -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM +import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM import Foreign.C.Types import Foreign.Ptr import Foreign.Storable (Storable) diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs index d0e8d24..ff2e45c 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} @@ -11,10 +12,45 @@ import Data.Proxy import Data.Type.Equality import GHC.TypeLits +import Data.Array.Mixed.Permutation import Data.Array.Mixed.Shape import Data.Array.Mixed.Types +-- * Reasoning helpers + +subst1 :: forall f a b. a :~: b -> f a :~: f b +subst1 Refl = Refl + +subst2 :: forall f c a b. a :~: b -> f a c :~: f b c +subst2 Refl = Refl + + +-- * Lemmas + +-- ** Nat + +lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True +lemLeqSuccSucc _ _ = unsafeCoerceRefl + +lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True +lemLeqPlus _ _ _ = Refl + + +-- ** Append + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerceRefl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerceRefl + +lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l +lemAppLeft _ Refl = Refl + + +-- ** Rank + lemRankApp :: forall sh1 sh2. StaticShX sh1 -> StaticShX sh2 -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 @@ -38,10 +74,58 @@ lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this -lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) +lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate SZ = Refl +lemRankReplicate (SS (n :: SNat nm1)) + | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 + , Refl <- lemRankReplicate n + = Refl + + +-- ** Various type families + +lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a + -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp sn _ _ = go sn + where + go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a + go SZ = Refl + go (SS (n :: SNat n'm1)) + | Refl <- lemReplicateSucc @a @n'm1 + , Refl <- go n + = sym (lemReplicateSucc @a @(n'm1 + m)) + +lemDropLenApp :: Rank l1 <= Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) +lemDropLenApp _ _ _ = unsafeCoerceRefl + +lemTakeLenApp :: Rank l1 <= Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) +lemTakeLenApp _ _ _ = unsafeCoerceRefl + + +-- ** KnownNat + +lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1) +lemKnownNatSucc = Dict + +lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh) lemKnownNatRank ZSX = Dict lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) lemKnownNatRankSSX ZKX = Dict lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict + + +-- ** Known shapes + +lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) +lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) + +lemKnownShX :: StaticShX sh -> Dict KnownShX sh +lemKnownShX ZKX = Dict +lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict +lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index 1df0ec7..83a5ee4 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -1,6 +1,7 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} @@ -25,7 +26,7 @@ import Data.Type.Equality import Data.Type.Ord import GHC.TypeError import GHC.TypeLits -import qualified GHC.TypeNats as TN +import GHC.TypeNats qualified as TN import Data.Array.Mixed.Shape import Data.Array.Mixed.Types @@ -112,14 +113,14 @@ listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape" listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f listxPermute PNil _ = ZX listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = - listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) + listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f -listxIndex _ _ SZ (n ::% _) rest = n ::% rest -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest +listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) +listxIndex _ _ SZ (n ::% _) = n +listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh rest -listxIndex _ _ _ ZX _ = error "Index into empty shape" + = listxIndex p pT i sh +listxIndex _ _ _ ZX = error "Index into empty shape" listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) @@ -136,7 +137,7 @@ ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) ssxPermute = coerce (listxPermute @(SMayNat () SNat)) -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index 363b772..a13a176 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -2,6 +2,7 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} @@ -22,7 +23,9 @@ module Data.Array.Mixed.Shape where import Control.DeepSeq (NFData(..)) -import qualified Data.Foldable as Foldable +import Data.Bifunctor (first) +import Data.Coerce +import Data.Foldable qualified as Foldable import Data.Functor.Const import Data.Kind (Type, Constraint) import Data.Monoid (Sum(..)) @@ -30,12 +33,10 @@ import Data.Proxy import Data.Type.Equality import GHC.Generics (Generic) import GHC.IsList (IsList) -import qualified GHC.IsList as IsList +import GHC.IsList qualified as IsList import GHC.TypeLits import Data.Array.Mixed.Types -import Data.Coerce -import Data.Bifunctor (first) -- | The length of a type-level list. If the argument is a shape, then the @@ -307,6 +308,10 @@ shxEnum = \sh -> go sh id [] go ZSX f = (f ZIX :) go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] +shxRank :: ShX sh f -> SNat (Rank sh) +shxRank ZSX = SNat +shxRank (_ :$% sh) | SNat <- shxRank sh = SNat + -- * Static mixed shapes @@ -377,6 +382,10 @@ ssxFromShape :: IShX sh -> StaticShX sh ssxFromShape ZSX = ZKX ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh +ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) +ssxFromSNat SZ = ZKX +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n + -- | Evidence for the static part of a shape. This pops up only when you are -- polymorphic in the element type of an array. diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs index d77513f..52201df 100644 --- a/src/Data/Array/Mixed/Types.hs +++ b/src/Data/Array/Mixed/Types.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} @@ -22,10 +23,10 @@ module Data.Array.Mixed.Types ( -- * Type-level lists type (++), - lemAppNil, - lemAppAssoc, Replicate, lemReplicateSucc, + MapJust, + Tail, -- * Unsafe unsafeCoerceRefl, @@ -34,8 +35,8 @@ module Data.Array.Mixed.Types ( import Data.Type.Equality import Data.Proxy import GHC.TypeLits -import qualified GHC.TypeNats as TN -import qualified Unsafe.Coerce +import GHC.TypeNats qualified as TN +import Unsafe.Coerce qualified -- | Evidence for the constraint @c a@. @@ -76,6 +77,26 @@ snatMul :: SNat n -> SNat m -> SNat (n * m) snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce +-- | Type-level list append. +type family l1 ++ l2 where + '[] ++ l2 = l2 + (x : xs) ++ l2 = x : xs ++ l2 + +type family Replicate n a where + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a + +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerceRefl + +type family MapJust l where + MapJust '[] = '[] + MapJust (x : xs) = Just x : MapJust xs + +type family Tail l where + Tail (_ : xs) = xs + + -- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to -- only typecheck for actual type equalities. One cannot, e.g. accidentally -- write this: @@ -89,22 +110,3 @@ snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsa -- but would have resulted in interesting memory errors at runtime. unsafeCoerceRefl :: a :~: b unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl - - --- | Type-level list append. -type family l1 ++ l2 where - '[] ++ l2 = l2 - (x : xs) ++ l2 = x : xs ++ l2 - -lemAppNil :: l ++ '[] :~: l -lemAppNil = unsafeCoerceRefl - -lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) -lemAppAssoc _ _ _ = unsafeCoerceRefl - -type family Replicate n a where - Replicate 0 a = '[] - Replicate n a = a : Replicate (n - 1) a - -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerceRefl |