diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2026-03-16 14:36:36 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2026-03-16 14:36:36 +0100 |
| commit | fc2a1370d67f12b50e3e5750d17aefd33bc3d8a3 (patch) | |
| tree | 5a3bb2d84adcd6e9bca74034d884837996467ea1 /src/Data | |
| parent | 8409fa81c7b31bf8ace0b1f219ba6a1a7cbdf2de (diff) | |
Fill and clean up *TakeIx and *DropIx functionsmvecsReplicate
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 12 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 10 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 7 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 11 |
4 files changed, 31 insertions, 9 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 4dd350a..b0276c9 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -610,13 +610,19 @@ shxTakeSh p (_ :$% ssh1) (n :$% sh) = n :$% shxTakeSh p ssh1 sh shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i shxDropSSX = coerce (listhDrop @i @()) +shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSh = coerce (listhDrop @i @i) + +{-# INLINEABLE shxTakeIx #-} +shxTakeIx :: forall sh sh' i j. Proxy sh' -> IxX sh j -> ShX (sh ++ sh') i -> ShX sh i +shxTakeIx _ (IxX ZX) _ = ZSX +shxTakeIx proxy (IxX (_ ::% long)) short = case short of i :$% short' -> i :$% shxTakeIx proxy (IxX long) short' + +{-# INLINEABLE shxDropIx #-} shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i shxDropIx ZIX long = long shxDropIx (_ :.% short) long = case long of _ :$% long' -> shxDropIx short long' -shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSh = coerce (listhDrop @i @i) - shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i shxInit = coerce (listhInit @i) diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index e44ab64..c5cdf6b 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -422,6 +422,16 @@ shrTail | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = coerce (shxTail @_ @_ @i) +{-# INLINEABLE shrTakeIx #-} +shrTakeIx :: forall n n' i j. Proxy n' -> IxR n j -> ShR (n + n') i -> ShR n i +shrTakeIx _ ZIR _ = ZSR +shrTakeIx p (_ :.: idx) sh = case sh of n :$: sh' -> n :$: shrTakeIx p idx sh' + +{-# INLINEABLE shrDropIx #-} +shrDropIx :: forall n n' i j. IxR n j -> ShR (n + n') i -> ShR n' i +shrDropIx ZIR long = long +shrDropIx (_ :.: short) long = case long of _ :$: long' -> shrDropIx short long' + shrInit :: forall n i. ShR (n + 1) i -> ShR n i shrInit | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 08711b6..d57106a 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -56,16 +56,11 @@ ssize = shsSize . sshape sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx) -{-# INLINEABLE shsTakeIx #-} -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IxS sh i -> ShS sh -shsTakeIx _ _ ZIS = ZSS -shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx - {-# INLINEABLE sindexPartial #-} sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a sindexPartial sarr@(Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) + (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) idx (sshape sarr)) (Proxy @sh2))) arr) (ixxFromIxS idx)) -- | __WARNING__: All values returned from the function must have equal shape. diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index f98c860..ef91c7b 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -30,6 +30,7 @@ import Data.Array.Shape qualified as O import Data.Coerce (coerce) import Data.Foldable qualified as Foldable import Data.Kind (Constraint, Type) +import Data.Proxy import Data.Type.Equality import GHC.Exts (build, withDict) import GHC.IsList (IsList) @@ -394,6 +395,16 @@ shsHead (ShS shx) = case shxHead shx of shsTail :: forall n sh. ShS (n : sh) -> ShS sh shsTail = coerce (shxTail @_ @_ @Int) +{-# INLINEABLE shsTakeIx #-} +shsTakeIx :: forall sh sh' j. Proxy sh' -> IxS sh j -> ShS (sh ++ sh') -> ShS sh +shsTakeIx _ ZIS _ = ZSS +shsTakeIx p (_ :.$ idx) sh = case sh of n :$$ sh' -> n :$$ shsTakeIx p idx sh' + +{-# INLINEABLE shsDropIx #-} +shsDropIx :: forall sh sh' j. IxS sh j -> ShS (sh ++ sh') -> ShS sh' +shsDropIx ZIS long = long +shsDropIx (_ :.$ short) long = case long of _ :$$ long' -> shsDropIx short long' + shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh)) shsInit = gcastWith (unsafeCoerceRefl |
