From 86de413131773f64e1bfd71dd080eb64812a87ee Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sat, 25 May 2024 10:45:54 +0200
Subject: replicate -> replicateScal; add proper generic replicate

---
 src/Data/Array/Mixed.hs           | 13 ++++++--
 src/Data/Array/Nested.hs          |  6 ++--
 src/Data/Array/Nested/Internal.hs | 65 ++++++++++++++++++++++++++-------------
 3 files changed, 57 insertions(+), 27 deletions(-)

(limited to 'src')

diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 748914c..d894b85 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -514,8 +514,17 @@ cast ssh1 sh2 ssh' (XArray arr)
 unScalar :: Storable a => XArray '[] a -> a
 unScalar (XArray a) = S.unScalar a
 
-replicate :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
-replicate sh x
+replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
+replicate sh ssh' (XArray arr)
+  | Dict <- lemKnownNatRankSSX ssh'
+  , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh')
+  , Refl <- lemRankApp (staticShapeFrom sh) ssh'
+  = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $
+            S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $
+              arr)
+
+replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
+replicateScal sh x
   | Dict <- lemKnownNatRank sh
   = XArray (S.constant (shapeLshape sh) x)
 
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 49923d8..968ea18 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -9,7 +9,7 @@ module Data.Array.Nested (
   rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
   rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,
   rrerank,
-  rreplicate, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
+  rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
   rslice, rrev1, rreshape, riota,
   -- ** Lifting orthotope operations to 'Ranked' arrays
   rlift, rlift2,
@@ -26,7 +26,7 @@ module Data.Array.Nested (
   sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
   stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
   srerank,
-  sreplicate, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
+  sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
   sslice, srev1, sreshape, siota,
   -- ** Lifting orthotope operations to 'Shaped' arrays
   slift, slift2,
@@ -41,7 +41,7 @@ module Data.Array.Nested (
   mshape, mindex, mindexPartial, mgenerate, msumOuter1,
   mtranspose, mappend, mscalar, mfromVector, mtoVector, munScalar,
   mrerank,
-  mreplicate, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
+  mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
   mslice, mrev1, mreshape, miota,
   -- ** Lifting orthotope operations to 'Mixed' arrays
   mlift, mlift2,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index c70da54..b7308fa 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -967,12 +967,22 @@ mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
 mrerank ssh sh2 f (toPrimitive -> arr) =
   fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
 
-mreplicateP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mreplicateP sh x = M_Primitive sh (X.replicate sh x)
-
-mreplicate :: forall sh a. PrimElt a
-          => IShX sh -> a -> Mixed sh a
-mreplicate sh x = fromPrimitive (mreplicateP sh x)
+mreplicate :: forall sh sh' a. Elt a
+           => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
+mreplicate sh arr =
+  let ssh' = X.staticShapeFrom (mshape arr)
+  in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh')
+           (\(sshT :: StaticShX shT) ->
+              case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
+                Refl -> X.replicate sh (ssxAppend ssh' sshT))
+           arr
+
+mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+
+mreplicateScal :: forall sh a. PrimElt a
+               => IShX sh -> a -> Mixed sh a
+mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
 
 mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
 mslice i n arr =
@@ -1352,15 +1362,15 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where
   negate = arithPromoteRanked negate
   abs = arithPromoteRanked abs
   signum = arithPromoteRanked signum
-  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate"
+  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal"
 
 instance (NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where
-  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate"
+  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal"
   recip = arithPromoteRanked recip
   (/) = arithPromoteRanked2 (/)
 
 instance (NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where
-  pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate"
+  pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal"
   exp = arithPromoteRanked exp
   log = arithPromoteRanked log
   sqrt = arithPromoteRanked sqrt
@@ -1576,14 +1586,20 @@ rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
 rrerank ssh sh2 f (rtoPrimitive -> arr) =
   rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr
 
-rreplicateP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
-rreplicateP sh x
+rreplicate :: forall n m a. Elt a
+           => IShR n -> Ranked m a -> Ranked (n + m) a
+rreplicate sh (Ranked arr)
+  | Refl <- lemReplicatePlusApp (snatFromShR sh) (Proxy @m) (Proxy @(Nothing @Nat))
+  = Ranked (mreplicate (shCvtRX sh) arr)
+
+rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicateScalP sh x
   | Dict <- lemKnownReplicate (snatFromShR sh)
-  = Ranked (mreplicateP (shCvtRX sh) x)
+  = Ranked (mreplicateScalP (shCvtRX sh) x)
 
-rreplicate :: forall n a. PrimElt a
-          => IShR n -> a -> Ranked n a
-rreplicate sh x = rfromPrimitive (rreplicateP sh x)
+rreplicateScal :: forall n a. PrimElt a
+               => IShR n -> a -> Ranked n a
+rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
 
 rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
 rslice i n arr
@@ -1654,15 +1670,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
   negate = arithPromoteShaped negate
   abs = arithPromoteShaped abs
   signum = arithPromoteShaped signum
-  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate"
+  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal"
 
 instance (NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where
-  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicate"
+  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
   recip = arithPromoteShaped recip
   (/) = arithPromoteShaped2 (/)
 
 instance (NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where
-  pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicate"
+  pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
   exp = arithPromoteShaped exp
   log = arithPromoteShaped log
   sqrt = arithPromoteShaped sqrt
@@ -1884,11 +1900,16 @@ srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
 srerank sh sh2 f (stoPrimitive -> arr) =
   sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
 
-sreplicateP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
-sreplicateP sh x = Shaped (mreplicateP (shCvtSX sh) x)
+sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
+sreplicate sh (Shaped arr)
+  | Refl <- lemCommMapJustApp sh (Proxy @sh')
+  = Shaped (mreplicate (shCvtSX sh) arr)
+
+sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x)
 
-sreplicate :: PrimElt a => ShS sh -> a -> Shaped sh a
-sreplicate sh x = sfromPrimitive (sreplicateP sh x)
+sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
 
 sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
 sslice i n@SNat arr =
-- 
cgit v1.2.3-70-g09d2