From 52c0237fbdbc3c99ee6565ba18250360a330fb8b Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 20 May 2024 17:21:21 +0200 Subject: Rerank on primitive arrays --- src/Data/Array/Mixed.hs | 10 ++++++ src/Data/Array/Nested.hs | 3 ++ src/Data/Array/Nested/Internal.hs | 68 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 80 insertions(+), 1 deletion(-) (limited to 'src/Data') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 33c0dd6..2f23903 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -320,12 +320,22 @@ listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short ixDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i ixDrop = coerce (listxDrop @(Const i) @(Const i)) +shDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i +shDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) + shDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i shDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) shDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i shDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) +shTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i +shTakeSSX _ = flip go + where + go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i + go ZKX _ = ZSX + go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh + ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 51754d0..2208349 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -8,6 +8,7 @@ module Data.Array.Nested ( ShR(.., ZSR, (:$:)), rshape, rindex, rindexPartial, rgenerate, rsumOuter1, rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar, + rrerank, rreplicate, rfromList, rfromList1, rtoList, rtoList1, rslice, rrev1, rreshape, -- ** Lifting orthotope operations to 'Ranked' arrays @@ -23,6 +24,7 @@ module Data.Array.Nested ( ShS(.., ZSS, (:$$)), KnownShS(..), sshape, sindex, sindexPartial, sgenerate, ssumOuter1, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, + srerank, sreplicate, sfromList, sfromList1, stoList, stoList1, sslice, srev1, sreshape, -- ** Lifting orthotope operations to 'Shaped' arrays @@ -36,6 +38,7 @@ module Data.Array.Nested ( IxX(..), IIxX, KnownShX(..), StaticShX(..), mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar, + mrerank, mreplicate, mfromList, mtoList, mslice, mrev1, mreshape, -- ** Conversions masXArrayPrim, mfromXArrayPrim, diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 99c4a46..badb910 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -28,7 +28,6 @@ {-| TODO: -* Write `rerank` * Write `rconst :: OR.Array n a -> Ranked n a` -} @@ -900,6 +899,24 @@ mtoList = map munScalar . mtoList1 munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX +mrerankP :: forall sh1 sh2 sh a. Storable a + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)) + -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive a) +mrerankP ssh sh2 f (M_Primitive sh arr) = + let sh1 = shDropSSX sh ssh + in M_Primitive (X.shAppend (shTakeSSX (Proxy @sh1) sh ssh) sh2) + (X.rerank ssh (X.staticShapeFrom sh1) (X.staticShapeFrom sh2) + (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) + arr) + +mrerank :: forall sh1 sh2 sh a. (Storable a, PrimElt a) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 a -> Mixed sh2 a) + -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) a +mrerank ssh sh2 f (toPrimitive -> arr) = + fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr + mreplicateP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) mreplicateP sh x = M_Primitive sh (X.replicate sh x) @@ -1389,6 +1406,24 @@ rtoList1 = map runScalar . rtoList runScalar :: Elt a => Ranked 0 a -> a runScalar arr = rindex arr ZIR +rrerankP :: forall n1 n2 n a. Storable a + => SNat n -> IShR n2 + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive a)) + -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive a) +rrerankP sn sh2 f (Ranked arr) + | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) + , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) + = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr) + +rrerank :: forall n1 n2 n a. (Storable a, PrimElt a) + => SNat n -> IShR n2 + -> (Ranked n1 a -> Ranked n2 a) + -> Ranked (n + n1) a -> Ranked (n + n2) a +rrerank ssh sh2 f (rtoPrimitive -> arr) = + rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr + rreplicateP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) rreplicateP sh x | Dict <- lemKnownReplicate (snatFromShR sh) @@ -1438,6 +1473,12 @@ rcastToShaped (Ranked arr) targetsh , Refl <- lemRankMapJust targetsh = mcastToShaped arr targetsh +rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a +rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) + +rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) +rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) + -- ====== API OF SHAPED ARRAYS ====== -- @@ -1619,6 +1660,25 @@ stoList1 = map sunScalar . stoList sunScalar :: Elt a => Shaped '[] a -> a sunScalar arr = sindex arr ZIS +srerankP :: forall sh1 sh2 sh a. Storable a + => ShS sh -> ShS sh2 + -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive a)) + -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive a) +srerankP sh sh2 f sarr@(Shaped arr) + | Refl <- lemCommMapJustApp sh (Proxy @sh1) + , Refl <- lemCommMapJustApp sh (Proxy @sh2) + = Shaped (mrerankP (X.staticShapeFrom (shTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (X.staticShapeFrom (shCvtSX sh)))) + (shCvtSX sh2) + (\a -> let Shaped r = f (Shaped a) in r) + arr) + +srerank :: forall sh1 sh2 sh a. (Storable a, PrimElt a) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 a -> Mixed sh2 a) + -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) a +srerank ssh sh2 f (toPrimitive -> arr) = + fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr + sreplicateP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) sreplicateP sh x = Shaped (mreplicateP (shCvtSX sh) x) @@ -1652,3 +1712,9 @@ stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a stoRanked sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) = mtoRanked arr + +sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a +sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) + +stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) +stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) -- cgit v1.2.3-70-g09d2