diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-25 10:45:54 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-25 10:45:54 +0200 |
commit | 86de413131773f64e1bfd71dd080eb64812a87ee (patch) | |
tree | fab17164f9e6a3cb5061b7df7759ef200e0b50e8 /src/Data/Array/Mixed.hs | |
parent | d4e328cc5edb171501adc5e6abdfff6e45aace3e (diff) |
replicate -> replicateScal; add proper generic replicate
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 748914c..d894b85 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -514,8 +514,17 @@ cast ssh1 sh2 ssh' (XArray arr) unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a -replicate :: forall sh a. Storable a => IShX sh -> a -> XArray sh a -replicate sh x +replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a +replicate sh ssh' (XArray arr) + | Dict <- lemKnownNatRankSSX ssh' + , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh') + , Refl <- lemRankApp (staticShapeFrom sh) ssh' + = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $ + arr) + +replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a +replicateScal sh x | Dict <- lemKnownNatRank sh = XArray (S.constant (shapeLshape sh) x) |