From 86de413131773f64e1bfd71dd080eb64812a87ee Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 25 May 2024 10:45:54 +0200 Subject: replicate -> replicateScal; add proper generic replicate --- src/Data/Array/Nested/Internal.hs | 65 ++++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 22 deletions(-) (limited to 'src/Data/Array/Nested/Internal.hs') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index c70da54..b7308fa 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -967,12 +967,22 @@ mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) mrerank ssh sh2 f (toPrimitive -> arr) = fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr -mreplicateP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateP sh x = M_Primitive sh (X.replicate sh x) - -mreplicate :: forall sh a. PrimElt a - => IShX sh -> a -> Mixed sh a -mreplicate sh x = fromPrimitive (mreplicateP sh x) +mreplicate :: forall sh sh' a. Elt a + => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a +mreplicate sh arr = + let ssh' = X.staticShapeFrom (mshape arr) + in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh') + (\(sshT :: StaticShX shT) -> + case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of + Refl -> X.replicate sh (ssxAppend ssh' sshT)) + arr + +mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) + +mreplicateScal :: forall sh a. PrimElt a + => IShX sh -> a -> Mixed sh a +mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a mslice i n arr = @@ -1352,15 +1362,15 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where negate = arithPromoteRanked negate abs = arithPromoteRanked abs signum = arithPromoteRanked signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate" + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal" instance (NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate" + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal" recip = arithPromoteRanked recip (/) = arithPromoteRanked2 (/) instance (NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate" + pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal" exp = arithPromoteRanked exp log = arithPromoteRanked log sqrt = arithPromoteRanked sqrt @@ -1576,14 +1586,20 @@ rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) rrerank ssh sh2 f (rtoPrimitive -> arr) = rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr -rreplicateP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateP sh x +rreplicate :: forall n m a. Elt a + => IShR n -> Ranked m a -> Ranked (n + m) a +rreplicate sh (Ranked arr) + | Refl <- lemReplicatePlusApp (snatFromShR sh) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked (mreplicate (shCvtRX sh) arr) + +rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicateScalP sh x | Dict <- lemKnownReplicate (snatFromShR sh) - = Ranked (mreplicateP (shCvtRX sh) x) + = Ranked (mreplicateScalP (shCvtRX sh) x) -rreplicate :: forall n a. PrimElt a - => IShR n -> a -> Ranked n a -rreplicate sh x = rfromPrimitive (rreplicateP sh x) +rreplicateScal :: forall n a. PrimElt a + => IShR n -> a -> Ranked n a +rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a rslice i n arr @@ -1654,15 +1670,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where negate = arithPromoteShaped negate abs = arithPromoteShaped abs signum = arithPromoteShaped signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate" + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal" instance (NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicate" + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" recip = arithPromoteShaped recip (/) = arithPromoteShaped2 (/) instance (NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicate" + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" exp = arithPromoteShaped exp log = arithPromoteShaped log sqrt = arithPromoteShaped sqrt @@ -1884,11 +1900,16 @@ srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) srerank sh sh2 f (stoPrimitive -> arr) = sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr -sreplicateP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateP sh x = Shaped (mreplicateP (shCvtSX sh) x) +sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a +sreplicate sh (Shaped arr) + | Refl <- lemCommMapJustApp sh (Proxy @sh') + = Shaped (mreplicate (shCvtSX sh) arr) + +sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) -sreplicate :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicate sh x = sfromPrimitive (sreplicateP sh x) +sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a +sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a sslice i n@SNat arr = -- cgit v1.2.3-70-g09d2