aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-25 10:45:54 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-25 10:45:54 +0200
commit86de413131773f64e1bfd71dd080eb64812a87ee (patch)
treefab17164f9e6a3cb5061b7df7759ef200e0b50e8 /src
parentd4e328cc5edb171501adc5e6abdfff6e45aace3e (diff)
replicate -> replicateScal; add proper generic replicate
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Mixed.hs13
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal.hs65
3 files changed, 57 insertions, 27 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 748914c..d894b85 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -514,8 +514,17 @@ cast ssh1 sh2 ssh' (XArray arr)
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a
-replicate :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
-replicate sh x
+replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
+replicate sh ssh' (XArray arr)
+ | Dict <- lemKnownNatRankSSX ssh'
+ , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh')
+ , Refl <- lemRankApp (staticShapeFrom sh) ssh'
+ = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $
+ S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $
+ arr)
+
+replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
+replicateScal sh x
| Dict <- lemKnownNatRank sh
= XArray (S.constant (shapeLshape sh) x)
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 49923d8..968ea18 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -9,7 +9,7 @@ module Data.Array.Nested (
rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,
rrerank,
- rreplicate, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
+ rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
rslice, rrev1, rreshape, riota,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift, rlift2,
@@ -26,7 +26,7 @@ module Data.Array.Nested (
sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
srerank,
- sreplicate, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
+ sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
sslice, srev1, sreshape, siota,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift, slift2,
@@ -41,7 +41,7 @@ module Data.Array.Nested (
mshape, mindex, mindexPartial, mgenerate, msumOuter1,
mtranspose, mappend, mscalar, mfromVector, mtoVector, munScalar,
mrerank,
- mreplicate, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
+ mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
mslice, mrev1, mreshape, miota,
-- ** Lifting orthotope operations to 'Mixed' arrays
mlift, mlift2,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index c70da54..b7308fa 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -967,12 +967,22 @@ mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
mrerank ssh sh2 f (toPrimitive -> arr) =
fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
-mreplicateP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mreplicateP sh x = M_Primitive sh (X.replicate sh x)
-
-mreplicate :: forall sh a. PrimElt a
- => IShX sh -> a -> Mixed sh a
-mreplicate sh x = fromPrimitive (mreplicateP sh x)
+mreplicate :: forall sh sh' a. Elt a
+ => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
+mreplicate sh arr =
+ let ssh' = X.staticShapeFrom (mshape arr)
+ in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh')
+ (\(sshT :: StaticShX shT) ->
+ case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
+ Refl -> X.replicate sh (ssxAppend ssh' sshT))
+ arr
+
+mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+
+mreplicateScal :: forall sh a. PrimElt a
+ => IShX sh -> a -> Mixed sh a
+mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
mslice i n arr =
@@ -1352,15 +1362,15 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where
negate = arithPromoteRanked negate
abs = arithPromoteRanked abs
signum = arithPromoteRanked signum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate"
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal"
instance (NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate"
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal"
recip = arithPromoteRanked recip
(/) = arithPromoteRanked2 (/)
instance (NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate"
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal"
exp = arithPromoteRanked exp
log = arithPromoteRanked log
sqrt = arithPromoteRanked sqrt
@@ -1576,14 +1586,20 @@ rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
rrerank ssh sh2 f (rtoPrimitive -> arr) =
rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr
-rreplicateP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
-rreplicateP sh x
+rreplicate :: forall n m a. Elt a
+ => IShR n -> Ranked m a -> Ranked (n + m) a
+rreplicate sh (Ranked arr)
+ | Refl <- lemReplicatePlusApp (snatFromShR sh) (Proxy @m) (Proxy @(Nothing @Nat))
+ = Ranked (mreplicate (shCvtRX sh) arr)
+
+rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicateScalP sh x
| Dict <- lemKnownReplicate (snatFromShR sh)
- = Ranked (mreplicateP (shCvtRX sh) x)
+ = Ranked (mreplicateScalP (shCvtRX sh) x)
-rreplicate :: forall n a. PrimElt a
- => IShR n -> a -> Ranked n a
-rreplicate sh x = rfromPrimitive (rreplicateP sh x)
+rreplicateScal :: forall n a. PrimElt a
+ => IShR n -> a -> Ranked n a
+rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
rslice i n arr
@@ -1654,15 +1670,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
negate = arithPromoteShaped negate
abs = arithPromoteShaped abs
signum = arithPromoteShaped signum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate"
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal"
instance (NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicate"
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
recip = arithPromoteShaped recip
(/) = arithPromoteShaped2 (/)
instance (NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicate"
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
exp = arithPromoteShaped exp
log = arithPromoteShaped log
sqrt = arithPromoteShaped sqrt
@@ -1884,11 +1900,16 @@ srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
srerank sh sh2 f (stoPrimitive -> arr) =
sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
-sreplicateP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
-sreplicateP sh x = Shaped (mreplicateP (shCvtSX sh) x)
+sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
+sreplicate sh (Shaped arr)
+ | Refl <- lemCommMapJustApp sh (Proxy @sh')
+ = Shaped (mreplicate (shCvtSX sh) arr)
+
+sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x)
-sreplicate :: PrimElt a => ShS sh -> a -> Shaped sh a
-sreplicate sh x = sfromPrimitive (sreplicateP sh x)
+sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
sslice i n@SNat arr =