diff options
author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-17 00:20:47 +0200 |
---|---|---|
committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-17 00:26:59 +0200 |
commit | 713d76559c42129afb24843af4386d18f1827727 (patch) | |
tree | 476c00426dd778f1adaba50f069cd949a2992c34 /src/Data/Array/Nested | |
parent | 0d6f77762d1697ea41b190e550f94e459079e3d1 (diff) |
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 5 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Lemmas.hs | 59 | ||||
-rw-r--r-- | src/Data/Array/Nested/Lemmas.hs | 162 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 11 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 2 |
9 files changed, 174 insertions, 73 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index d5e6008..dd26f16 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -26,15 +26,14 @@ import Control.Category import Data.Proxy import Data.Type.Equality -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Types -import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Base import Data.Array.Nested.Ranked.Shape import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs deleted file mode 100644 index b1589e0..0000000 --- a/src/Data/Array/Nested/Internal/Lemmas.hs +++ /dev/null @@ -1,59 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Internal.Lemmas where - -import Data.Proxy -import Data.Type.Equality -import GHC.TypeLits - -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Shaped.Shape - - -lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh -lemRankMapJust ZSS = Refl -lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl - -lemMapJustApp :: ShS sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustApp ZSS _ = Refl -lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl - -lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) -lemTakeLenMapJust PNil _ = Refl -lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl -lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" - -lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) -lemDropLenMapJust PNil _ = Refl -lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl -lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" - -lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) -lemIndexMapJust SZ (_ :$$ _) = Refl -lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) - | Refl <- lemIndexMapJust i sh - , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) - , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = Refl -lemIndexMapJust _ ZSS = error "Index of empty" - -lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) -lemPermuteMapJust PNil _ = Refl -lemPermuteMapJust (i `PCons` is) sh - | Refl <- lemPermuteMapJust is sh - , Refl <- lemIndexMapJust i sh - = Refl - -lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) -lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) - where - go :: ShS sh' -> StaticShX (MapJust sh') - go ZSS = ZKX - go (n :$$ sh) = SKnown n :!% go sh diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs new file mode 100644 index 0000000..a3b20c6 --- /dev/null +++ b/src/Data/Array/Nested/Lemmas.hs @@ -0,0 +1,162 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Lemmas where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types + + +-- * Basic Lemmas (they don't mention shape types and don't require typing plugins) + +-- ** Nat + +lemLeqSuccSucc :: k + 1 <= n => Proxy k -> Proxy n -> (k <=? n - 1) :~: True +lemLeqSuccSucc _ _ = unsafeCoerceRefl + +lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True +lemLeqPlus _ _ _ = Refl + +-- ** Append + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerceRefl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerceRefl + +lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l +lemAppLeft _ Refl = Refl + +-- ** Simple type families + +lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a + -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp sn _ _ = go sn + where + go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a + go SZ = Refl + go (SS (n :: SNat n'm1)) + | Refl <- lemReplicateSucc @a @n'm1 + , Refl <- go n + = sym (lemReplicateSucc @a @(n'm1 + m)) + +lemDropLenApp :: Rank l1 <= Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) +lemDropLenApp _ _ _ = unsafeCoerceRefl + +lemTakeLenApp :: Rank l1 <= Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) +lemTakeLenApp _ _ _ = unsafeCoerceRefl + +lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l +lemInitApp _ _ = unsafeCoerceRefl + +lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x +lemLastApp _ _ = unsafeCoerceRefl + + +-- ** KnownNat + +lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1) +lemKnownNatSucc = Dict + +lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh) +lemKnownNatRank ZSX = Dict +lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict + +lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX ZKX = Dict +lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict + + +-- * Complex lemmas (they mention shape types) + +-- ** Known shapes + +lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) +lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) + +lemKnownShX :: StaticShX sh -> Dict KnownShX sh +lemKnownShX ZKX = Dict +lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict +lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict + +lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) +lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) + where + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKX + go (n :$$ sh) = SKnown n :!% go sh + +-- ** Rank + +lemRankApp :: forall sh1 sh2. + StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 +lemRankApp ZKX _ = Refl +lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 + = lem (Proxy @(Rank sh1T)) Proxy Proxy $ + sym (lemRankApp ssh1 ssh2) + where + lem :: proxy a -> proxy b -> proxy c + -> (a + b :~: c) + -> c + 1 :~: (a + 1 + b) + lem _ _ _ Refl = Refl + +lemRankAppComm :: proxy sh1 -> proxy sh2 + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) +lemRankAppComm _ _ = unsafeCoerceRefl + +lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = unsafeCoerceRefl + +lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust ZSS = Refl +lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl + +-- ** Related to MapJust and/or Permutation + +lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemTakeLenMapJust PNil _ = Refl +lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl +lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" + +lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemDropLenMapJust PNil _ = Refl +lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl +lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" + +lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) +lemIndexMapJust SZ (_ :$$ _) = Refl +lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemIndexMapJust i sh + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemIndexMapJust _ ZSS = error "Index of empty" + +lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemPermuteMapJust PNil _ = Refl +lemPermuteMapJust (i `PCons` is) sh + | Refl <- lemPermuteMapJust is sh + , Refl <- lemIndexMapJust i sh + = Refl + +lemMapJustApp :: ShS sh1 -> Proxy sh2 + -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustApp ZSS _ = Refl +lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 54bd5f2..7a86e4d 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -42,7 +42,7 @@ import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Permutation import Data.Array.Nested.Types import Data.Array.XArray (XArray(..)) diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index e5c51ef..973ec0e 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -29,7 +29,7 @@ import Foreign.Storable (Storable) import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Permutation import Data.Array.Nested.Types import Data.Array.XArray (XArray(..)) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index f50f671..beb5b0e 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -34,7 +34,7 @@ import GHC.TypeLits import Data.Foldable (toList) #endif -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Types import Data.Array.XArray (XArray(..)) import Data.Array.Nested.Mixed diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index c0c4f17..75a1e5b 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -39,7 +39,7 @@ import GHC.IsList qualified as IsList import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Types diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 7e38aee..01982a8 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -29,18 +29,17 @@ import Data.Vector.Storable qualified as VS import Foreign.Storable (Storable) import GHC.TypeLits -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.XArray (XArray) -import Data.Array.XArray qualified as X -import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 529ac21..fa84efe 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -32,7 +32,7 @@ import GHC.TypeLits import Data.Array.Nested.Types import Data.Array.XArray (XArray) -import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Shaped.Shape |