From 86de413131773f64e1bfd71dd080eb64812a87ee Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 25 May 2024 10:45:54 +0200 Subject: replicate -> replicateScal; add proper generic replicate --- src/Data/Array/Mixed.hs | 13 ++++++-- src/Data/Array/Nested.hs | 6 ++-- src/Data/Array/Nested/Internal.hs | 65 ++++++++++++++++++++++++++------------- 3 files changed, 57 insertions(+), 27 deletions(-) (limited to 'src/Data/Array') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 748914c..d894b85 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -514,8 +514,17 @@ cast ssh1 sh2 ssh' (XArray arr) unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a -replicate :: forall sh a. Storable a => IShX sh -> a -> XArray sh a -replicate sh x +replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a +replicate sh ssh' (XArray arr) + | Dict <- lemKnownNatRankSSX ssh' + , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh') + , Refl <- lemRankApp (staticShapeFrom sh) ssh' + = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $ + arr) + +replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a +replicateScal sh x | Dict <- lemKnownNatRank sh = XArray (S.constant (shapeLshape sh) x) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 49923d8..968ea18 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -9,7 +9,7 @@ module Data.Array.Nested ( rshape, rindex, rindexPartial, rgenerate, rsumOuter1, rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar, rrerank, - rreplicate, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1, + rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1, rslice, rrev1, rreshape, riota, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, rlift2, @@ -26,7 +26,7 @@ module Data.Array.Nested ( sshape, sindex, sindexPartial, sgenerate, ssumOuter1, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, srerank, - sreplicate, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1, + sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1, sslice, srev1, sreshape, siota, -- ** Lifting orthotope operations to 'Shaped' arrays slift, slift2, @@ -41,7 +41,7 @@ module Data.Array.Nested ( mshape, mindex, mindexPartial, mgenerate, msumOuter1, mtranspose, mappend, mscalar, mfromVector, mtoVector, munScalar, mrerank, - mreplicate, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1, + mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1, mslice, mrev1, mreshape, miota, -- ** Lifting orthotope operations to 'Mixed' arrays mlift, mlift2, diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index c70da54..b7308fa 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -967,12 +967,22 @@ mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) mrerank ssh sh2 f (toPrimitive -> arr) = fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr -mreplicateP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateP sh x = M_Primitive sh (X.replicate sh x) - -mreplicate :: forall sh a. PrimElt a - => IShX sh -> a -> Mixed sh a -mreplicate sh x = fromPrimitive (mreplicateP sh x) +mreplicate :: forall sh sh' a. Elt a + => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a +mreplicate sh arr = + let ssh' = X.staticShapeFrom (mshape arr) + in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh') + (\(sshT :: StaticShX shT) -> + case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of + Refl -> X.replicate sh (ssxAppend ssh' sshT)) + arr + +mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) + +mreplicateScal :: forall sh a. PrimElt a + => IShX sh -> a -> Mixed sh a +mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a mslice i n arr = @@ -1352,15 +1362,15 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where negate = arithPromoteRanked negate abs = arithPromoteRanked abs signum = arithPromoteRanked signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate" + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal" instance (NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate" + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal" recip = arithPromoteRanked recip (/) = arithPromoteRanked2 (/) instance (NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate" + pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal" exp = arithPromoteRanked exp log = arithPromoteRanked log sqrt = arithPromoteRanked sqrt @@ -1576,14 +1586,20 @@ rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) rrerank ssh sh2 f (rtoPrimitive -> arr) = rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr -rreplicateP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateP sh x +rreplicate :: forall n m a. Elt a + => IShR n -> Ranked m a -> Ranked (n + m) a +rreplicate sh (Ranked arr) + | Refl <- lemReplicatePlusApp (snatFromShR sh) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked (mreplicate (shCvtRX sh) arr) + +rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicateScalP sh x | Dict <- lemKnownReplicate (snatFromShR sh) - = Ranked (mreplicateP (shCvtRX sh) x) + = Ranked (mreplicateScalP (shCvtRX sh) x) -rreplicate :: forall n a. PrimElt a - => IShR n -> a -> Ranked n a -rreplicate sh x = rfromPrimitive (rreplicateP sh x) +rreplicateScal :: forall n a. PrimElt a + => IShR n -> a -> Ranked n a +rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a rslice i n arr @@ -1654,15 +1670,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where negate = arithPromoteShaped negate abs = arithPromoteShaped abs signum = arithPromoteShaped signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate" + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal" instance (NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicate" + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" recip = arithPromoteShaped recip (/) = arithPromoteShaped2 (/) instance (NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicate" + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" exp = arithPromoteShaped exp log = arithPromoteShaped log sqrt = arithPromoteShaped sqrt @@ -1884,11 +1900,16 @@ srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) srerank sh sh2 f (stoPrimitive -> arr) = sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr -sreplicateP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateP sh x = Shaped (mreplicateP (shCvtSX sh) x) +sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a +sreplicate sh (Shaped arr) + | Refl <- lemCommMapJustApp sh (Proxy @sh') + = Shaped (mreplicate (shCvtSX sh) arr) + +sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) -sreplicate :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicate sh x = sfromPrimitive (sreplicateP sh x) +sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a +sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a sslice i n@SNat arr = -- cgit v1.2.3-70-g09d2