From 86de413131773f64e1bfd71dd080eb64812a87ee Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 25 May 2024 10:45:54 +0200 Subject: replicate -> replicateScal; add proper generic replicate --- src/Data/Array/Mixed.hs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'src/Data/Array/Mixed.hs') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 748914c..d894b85 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -514,8 +514,17 @@ cast ssh1 sh2 ssh' (XArray arr) unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a -replicate :: forall sh a. Storable a => IShX sh -> a -> XArray sh a -replicate sh x +replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a +replicate sh ssh' (XArray arr) + | Dict <- lemKnownNatRankSSX ssh' + , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh') + , Refl <- lemRankApp (staticShapeFrom sh) ssh' + = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $ + arr) + +replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a +replicateScal sh x | Dict <- lemKnownNatRank sh = XArray (S.constant (shapeLshape sh) x) -- cgit v1.2.3-70-g09d2