aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-25 10:45:54 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-25 10:45:54 +0200
commit86de413131773f64e1bfd71dd080eb64812a87ee (patch)
treefab17164f9e6a3cb5061b7df7759ef200e0b50e8 /src/Data/Array/Nested
parentd4e328cc5edb171501adc5e6abdfff6e45aace3e (diff)
replicate -> replicateScal; add proper generic replicate
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs65
1 files changed, 43 insertions, 22 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index c70da54..b7308fa 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -967,12 +967,22 @@ mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
mrerank ssh sh2 f (toPrimitive -> arr) =
fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
-mreplicateP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mreplicateP sh x = M_Primitive sh (X.replicate sh x)
-
-mreplicate :: forall sh a. PrimElt a
- => IShX sh -> a -> Mixed sh a
-mreplicate sh x = fromPrimitive (mreplicateP sh x)
+mreplicate :: forall sh sh' a. Elt a
+ => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
+mreplicate sh arr =
+ let ssh' = X.staticShapeFrom (mshape arr)
+ in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh')
+ (\(sshT :: StaticShX shT) ->
+ case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
+ Refl -> X.replicate sh (ssxAppend ssh' sshT))
+ arr
+
+mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+
+mreplicateScal :: forall sh a. PrimElt a
+ => IShX sh -> a -> Mixed sh a
+mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
mslice i n arr =
@@ -1352,15 +1362,15 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where
negate = arithPromoteRanked negate
abs = arithPromoteRanked abs
signum = arithPromoteRanked signum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate"
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal"
instance (NumElt a, PrimElt a, Fractional a) => Fractional (Ranked n a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate"
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal"
recip = arithPromoteRanked recip
(/) = arithPromoteRanked2 (/)
instance (NumElt a, PrimElt a, Floating a) => Floating (Ranked n a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate"
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal"
exp = arithPromoteRanked exp
log = arithPromoteRanked log
sqrt = arithPromoteRanked sqrt
@@ -1576,14 +1586,20 @@ rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
rrerank ssh sh2 f (rtoPrimitive -> arr) =
rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr
-rreplicateP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
-rreplicateP sh x
+rreplicate :: forall n m a. Elt a
+ => IShR n -> Ranked m a -> Ranked (n + m) a
+rreplicate sh (Ranked arr)
+ | Refl <- lemReplicatePlusApp (snatFromShR sh) (Proxy @m) (Proxy @(Nothing @Nat))
+ = Ranked (mreplicate (shCvtRX sh) arr)
+
+rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicateScalP sh x
| Dict <- lemKnownReplicate (snatFromShR sh)
- = Ranked (mreplicateP (shCvtRX sh) x)
+ = Ranked (mreplicateScalP (shCvtRX sh) x)
-rreplicate :: forall n a. PrimElt a
- => IShR n -> a -> Ranked n a
-rreplicate sh x = rfromPrimitive (rreplicateP sh x)
+rreplicateScal :: forall n a. PrimElt a
+ => IShR n -> a -> Ranked n a
+rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
rslice i n arr
@@ -1654,15 +1670,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
negate = arithPromoteShaped negate
abs = arithPromoteShaped abs
signum = arithPromoteShaped signum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate"
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal"
instance (NumElt a, PrimElt a, Fractional a) => Fractional (Shaped sh a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicate"
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
recip = arithPromoteShaped recip
(/) = arithPromoteShaped2 (/)
instance (NumElt a, PrimElt a, Floating a) => Floating (Shaped sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicate"
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
exp = arithPromoteShaped exp
log = arithPromoteShaped log
sqrt = arithPromoteShaped sqrt
@@ -1884,11 +1900,16 @@ srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
srerank sh sh2 f (stoPrimitive -> arr) =
sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
-sreplicateP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
-sreplicateP sh x = Shaped (mreplicateP (shCvtSX sh) x)
+sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
+sreplicate sh (Shaped arr)
+ | Refl <- lemCommMapJustApp sh (Proxy @sh')
+ = Shaped (mreplicate (shCvtSX sh) arr)
+
+sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x)
-sreplicate :: PrimElt a => ShS sh -> a -> Shaped sh a
-sreplicate sh x = sfromPrimitive (sreplicateP sh x)
+sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
sslice i n@SNat arr =