diff options
Diffstat (limited to 'src/Data')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 5 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Lemmas.hs | 59 | ||||
-rw-r--r-- | src/Data/Array/Nested/Lemmas.hs (renamed from src/Data/Array/Mixed/Lemmas.hs) | 100 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 11 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/XArray.hs | 2 |
10 files changed, 85 insertions, 102 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index d5e6008..dd26f16 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -26,15 +26,14 @@ import Control.Category import Data.Proxy import Data.Type.Equality -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Types -import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Base import Data.Array.Nested.Ranked.Shape import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs deleted file mode 100644 index b1589e0..0000000 --- a/src/Data/Array/Nested/Internal/Lemmas.hs +++ /dev/null @@ -1,59 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Internal.Lemmas where - -import Data.Proxy -import Data.Type.Equality -import GHC.TypeLits - -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Shaped.Shape - - -lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh -lemRankMapJust ZSS = Refl -lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl - -lemMapJustApp :: ShS sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustApp ZSS _ = Refl -lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl - -lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) -lemTakeLenMapJust PNil _ = Refl -lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl -lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" - -lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) -lemDropLenMapJust PNil _ = Refl -lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl -lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" - -lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) -lemIndexMapJust SZ (_ :$$ _) = Refl -lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) - | Refl <- lemIndexMapJust i sh - , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) - , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = Refl -lemIndexMapJust _ ZSS = error "Index of empty" - -lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) -lemPermuteMapJust PNil _ = Refl -lemPermuteMapJust (i `PCons` is) sh - | Refl <- lemPermuteMapJust is sh - , Refl <- lemIndexMapJust i sh - = Refl - -lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) -lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) - where - go :: ShS sh' -> StaticShX (MapJust sh') - go ZSS = ZKX - go (n :$$ sh) = SKnown n :!% go sh diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs index e6d970c..a3b20c6 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Nested/Lemmas.hs @@ -6,7 +6,7 @@ {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Lemmas where +module Data.Array.Nested.Lemmas where import Data.Proxy import Data.Type.Equality @@ -14,10 +14,11 @@ import GHC.TypeLits import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation +import Data.Array.Nested.Shaped.Shape import Data.Array.Nested.Types --- * Lemmas +-- * Basic Lemmas (they don't mention shape types and don't require typing plugins) -- ** Nat @@ -27,7 +28,6 @@ lemLeqSuccSucc _ _ = unsafeCoerceRefl lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True lemLeqPlus _ _ _ = Refl - -- ** Append lemAppNil :: l ++ '[] :~: l @@ -39,31 +39,7 @@ lemAppAssoc _ _ _ = unsafeCoerceRefl lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l lemAppLeft _ Refl = Refl - --- ** Rank - -lemRankApp :: forall sh1 sh2. - StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 -lemRankApp ZKX _ = Refl -lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 - = lem (Proxy @(Rank sh1T)) Proxy Proxy $ - sym (lemRankApp ssh1 ssh2) - where - lem :: proxy a -> proxy b -> proxy c - -> (a + b :~: c) - -> c + 1 :~: (a + 1 + b) - lem _ _ _ Refl = Refl - -lemRankAppComm :: proxy sh1 -> proxy sh2 - -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) -lemRankAppComm _ _ = unsafeCoerceRefl - -lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = unsafeCoerceRefl - - --- ** Various type families +-- ** Simple type families lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a @@ -107,6 +83,8 @@ lemKnownNatRankSSX ZKX = Dict lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +-- * Complex lemmas (they mention shape types) + -- ** Known shapes lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) @@ -116,3 +94,69 @@ lemKnownShX :: StaticShX sh -> Dict KnownShX sh lemKnownShX ZKX = Dict lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict + +lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) +lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) + where + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKX + go (n :$$ sh) = SKnown n :!% go sh + +-- ** Rank + +lemRankApp :: forall sh1 sh2. + StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 +lemRankApp ZKX _ = Refl +lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 + = lem (Proxy @(Rank sh1T)) Proxy Proxy $ + sym (lemRankApp ssh1 ssh2) + where + lem :: proxy a -> proxy b -> proxy c + -> (a + b :~: c) + -> c + 1 :~: (a + 1 + b) + lem _ _ _ Refl = Refl + +lemRankAppComm :: proxy sh1 -> proxy sh2 + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) +lemRankAppComm _ _ = unsafeCoerceRefl + +lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = unsafeCoerceRefl + +lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust ZSS = Refl +lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl + +-- ** Related to MapJust and/or Permutation + +lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemTakeLenMapJust PNil _ = Refl +lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl +lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" + +lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemDropLenMapJust PNil _ = Refl +lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl +lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" + +lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) +lemIndexMapJust SZ (_ :$$ _) = Refl +lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemIndexMapJust i sh + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemIndexMapJust _ ZSS = error "Index of empty" + +lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemPermuteMapJust PNil _ = Refl +lemPermuteMapJust (i `PCons` is) sh + | Refl <- lemPermuteMapJust is sh + , Refl <- lemIndexMapJust i sh + = Refl + +lemMapJustApp :: ShS sh1 -> Proxy sh2 + -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustApp ZSS _ = Refl +lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 54bd5f2..7a86e4d 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -42,7 +42,7 @@ import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Permutation import Data.Array.Nested.Types import Data.Array.XArray (XArray(..)) diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index e5c51ef..973ec0e 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -29,7 +29,7 @@ import Foreign.Storable (Storable) import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Permutation import Data.Array.Nested.Types import Data.Array.XArray (XArray(..)) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index f50f671..beb5b0e 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -34,7 +34,7 @@ import GHC.TypeLits import Data.Foldable (toList) #endif -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Types import Data.Array.XArray (XArray(..)) import Data.Array.Nested.Mixed diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index c0c4f17..75a1e5b 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -39,7 +39,7 @@ import GHC.IsList qualified as IsList import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Types diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 7e38aee..01982a8 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -29,18 +29,17 @@ import Data.Vector.Storable qualified as VS import Foreign.Storable (Storable) import GHC.TypeLits -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.XArray (XArray) -import Data.Array.XArray qualified as X -import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 529ac21..fa84efe 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -32,7 +32,7 @@ import GHC.TypeLits import Data.Array.Nested.Types import Data.Array.XArray (XArray) -import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Shaped.Shape diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index dde06e3..12534a3 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -31,7 +31,7 @@ import Foreign.Storable (Storable) import GHC.Generics (Generic) import GHC.TypeLits -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Permutation import Data.Array.Nested.Types import Data.Array.Nested.Mixed.Shape |