aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-25 10:45:54 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-25 10:45:54 +0200
commit86de413131773f64e1bfd71dd080eb64812a87ee (patch)
treefab17164f9e6a3cb5061b7df7759ef200e0b50e8 /src/Data/Array/Mixed.hs
parentd4e328cc5edb171501adc5e6abdfff6e45aace3e (diff)
replicate -> replicateScal; add proper generic replicate
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs13
1 files changed, 11 insertions, 2 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 748914c..d894b85 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -514,8 +514,17 @@ cast ssh1 sh2 ssh' (XArray arr)
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a
-replicate :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
-replicate sh x
+replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
+replicate sh ssh' (XArray arr)
+ | Dict <- lemKnownNatRankSSX ssh'
+ , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh')
+ , Refl <- lemRankApp (staticShapeFrom sh) ssh'
+ = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $
+ S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $
+ arr)
+
+replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
+replicateScal sh x
| Dict <- lemKnownNatRank sh
= XArray (S.constant (shapeLshape sh) x)