aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs11
-rw-r--r--src/Data/Array/Mixed/Lemmas.hs86
-rw-r--r--src/Data/Array/Mixed/Permutation.hs17
-rw-r--r--src/Data/Array/Mixed/Shape.hs17
-rw-r--r--src/Data/Array/Mixed/Types.hs48
5 files changed, 138 insertions, 41 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index cf6820b..bb3ee4a 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
@@ -9,14 +10,14 @@
module Data.Array.Mixed.Internal.Arith where
import Control.Monad (forM, guard)
-import qualified Data.Array.Internal as OI
-import qualified Data.Array.Internal.RankedG as RG
-import qualified Data.Array.Internal.RankedS as RS
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as RG
+import Data.Array.Internal.RankedS qualified as RS
import Data.Bits
import Data.Int
import Data.List (sort)
-import qualified Data.Vector.Storable as VS
-import qualified Data.Vector.Storable.Mutable as VSM
+import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable (Storable)
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs
index d0e8d24..ff2e45c 100644
--- a/src/Data/Array/Mixed/Lemmas.hs
+++ b/src/Data/Array/Mixed/Lemmas.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
@@ -11,10 +12,45 @@ import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
+import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
+-- * Reasoning helpers
+
+subst1 :: forall f a b. a :~: b -> f a :~: f b
+subst1 Refl = Refl
+
+subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
+subst2 Refl = Refl
+
+
+-- * Lemmas
+
+-- ** Nat
+
+lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True
+lemLeqSuccSucc _ _ = unsafeCoerceRefl
+
+lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
+lemLeqPlus _ _ _ = Refl
+
+
+-- ** Append
+
+lemAppNil :: l ++ '[] :~: l
+lemAppNil = unsafeCoerceRefl
+
+lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
+lemAppAssoc _ _ _ = unsafeCoerceRefl
+
+lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
+lemAppLeft _ Refl = Refl
+
+
+-- ** Rank
+
lemRankApp :: forall sh1 sh2.
StaticShX sh1 -> StaticShX sh2
-> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
@@ -38,10 +74,58 @@ lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
-> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this
-lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
+lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate SZ = Refl
+lemRankReplicate (SS (n :: SNat nm1))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
+ , Refl <- lemRankReplicate n
+ = Refl
+
+
+-- ** Various type families
+
+lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
+ -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+lemReplicatePlusApp sn _ _ = go sn
+ where
+ go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
+ go SZ = Refl
+ go (SS (n :: SNat n'm1))
+ | Refl <- lemReplicateSucc @a @n'm1
+ , Refl <- go n
+ = sym (lemReplicateSucc @a @(n'm1 + m))
+
+lemDropLenApp :: Rank l1 <= Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
+lemDropLenApp _ _ _ = unsafeCoerceRefl
+
+lemTakeLenApp :: Rank l1 <= Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
+lemTakeLenApp _ _ _ = unsafeCoerceRefl
+
+
+-- ** KnownNat
+
+lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
+lemKnownNatSucc = Dict
+
+lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh)
lemKnownNatRank ZSX = Dict
lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX ZKX = Dict
lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+
+
+-- ** Known shapes
+
+lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
+lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn)
+
+lemKnownShX :: StaticShX sh -> Dict KnownShX sh
+lemKnownShX ZKX = Dict
+lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict
+lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
index 1df0ec7..83a5ee4 100644
--- a/src/Data/Array/Mixed/Permutation.hs
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
@@ -25,7 +26,7 @@ import Data.Type.Equality
import Data.Type.Ord
import GHC.TypeError
import GHC.TypeLits
-import qualified GHC.TypeNats as TN
+import GHC.TypeNats qualified as TN
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
@@ -112,14 +113,14 @@ listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape"
listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
listxPermute PNil _ = ZX
listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
- listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh)
+ listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh
-listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f
-listxIndex _ _ SZ (n ::% _) rest = n ::% rest
-listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest
+listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
+listxIndex _ _ SZ (n ::% _) = n
+listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
| Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listxIndex p pT i sh rest
-listxIndex _ _ _ ZX _ = error "Index into empty shape"
+ = listxIndex p pT i sh
+listxIndex _ _ _ ZX = error "Index into empty shape"
listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
@@ -136,7 +137,7 @@ ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute = coerce (listxPermute @(SMayNat () SNat))
-ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index 363b772..a13a176 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -2,6 +2,7 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
@@ -22,7 +23,9 @@
module Data.Array.Mixed.Shape where
import Control.DeepSeq (NFData(..))
-import qualified Data.Foldable as Foldable
+import Data.Bifunctor (first)
+import Data.Coerce
+import Data.Foldable qualified as Foldable
import Data.Functor.Const
import Data.Kind (Type, Constraint)
import Data.Monoid (Sum(..))
@@ -30,12 +33,10 @@ import Data.Proxy
import Data.Type.Equality
import GHC.Generics (Generic)
import GHC.IsList (IsList)
-import qualified GHC.IsList as IsList
+import GHC.IsList qualified as IsList
import GHC.TypeLits
import Data.Array.Mixed.Types
-import Data.Coerce
-import Data.Bifunctor (first)
-- | The length of a type-level list. If the argument is a shape, then the
@@ -307,6 +308,10 @@ shxEnum = \sh -> go sh id []
go ZSX f = (f ZIX :)
go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
+shxRank :: ShX sh f -> SNat (Rank sh)
+shxRank ZSX = SNat
+shxRank (_ :$% sh) | SNat <- shxRank sh = SNat
+
-- * Static mixed shapes
@@ -377,6 +382,10 @@ ssxFromShape :: IShX sh -> StaticShX sh
ssxFromShape ZSX = ZKX
ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh
+ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
+ssxFromSNat SZ = ZKX
+ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
+
-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs
index d77513f..52201df 100644
--- a/src/Data/Array/Mixed/Types.hs
+++ b/src/Data/Array/Mixed/Types.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
@@ -22,10 +23,10 @@ module Data.Array.Mixed.Types (
-- * Type-level lists
type (++),
- lemAppNil,
- lemAppAssoc,
Replicate,
lemReplicateSucc,
+ MapJust,
+ Tail,
-- * Unsafe
unsafeCoerceRefl,
@@ -34,8 +35,8 @@ module Data.Array.Mixed.Types (
import Data.Type.Equality
import Data.Proxy
import GHC.TypeLits
-import qualified GHC.TypeNats as TN
-import qualified Unsafe.Coerce
+import GHC.TypeNats qualified as TN
+import Unsafe.Coerce qualified
-- | Evidence for the constraint @c a@.
@@ -76,6 +77,26 @@ snatMul :: SNat n -> SNat m -> SNat (n * m)
snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
+-- | Type-level list append.
+type family l1 ++ l2 where
+ '[] ++ l2 = l2
+ (x : xs) ++ l2 = x : xs ++ l2
+
+type family Replicate n a where
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
+
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerceRefl
+
+type family MapJust l where
+ MapJust '[] = '[]
+ MapJust (x : xs) = Just x : MapJust xs
+
+type family Tail l where
+ Tail (_ : xs) = xs
+
+
-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to
-- only typecheck for actual type equalities. One cannot, e.g. accidentally
-- write this:
@@ -89,22 +110,3 @@ snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsa
-- but would have resulted in interesting memory errors at runtime.
unsafeCoerceRefl :: a :~: b
unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl
-
-
--- | Type-level list append.
-type family l1 ++ l2 where
- '[] ++ l2 = l2
- (x : xs) ++ l2 = x : xs ++ l2
-
-lemAppNil :: l ++ '[] :~: l
-lemAppNil = unsafeCoerceRefl
-
-lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
-lemAppAssoc _ _ _ = unsafeCoerceRefl
-
-type family Replicate n a where
- Replicate 0 a = '[]
- Replicate n a = a : Replicate (n - 1) a
-
-lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
-lemReplicateSucc = unsafeCoerceRefl