diff options
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 748914c..d894b85 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -514,8 +514,17 @@ cast ssh1 sh2 ssh' (XArray arr) unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a -replicate :: forall sh a. Storable a => IShX sh -> a -> XArray sh a -replicate sh x +replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a +replicate sh ssh' (XArray arr) + | Dict <- lemKnownNatRankSSX ssh' + , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh') + , Refl <- lemRankApp (staticShapeFrom sh) ssh' + = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $ + arr) + +replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a +replicateScal sh x | Dict <- lemKnownNatRank sh = XArray (S.constant (shapeLshape sh) x) |