aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-10 23:31:59 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-10 23:32:08 +0200
commit42b8c69a978b54001aeae62c8c37ce80500d6428 (patch)
tree2318a443fe2cd659cc432dbc6d7c53e8206ac4e8 /src/Data
parent5f6a64660b16d8f188caca5216e55debf4264611 (diff)
Add (temporary version of) more general mdot
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs10
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs7
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs7
4 files changed, 23 insertions, 7 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index b5c0772..f75a71c 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -11,7 +11,7 @@ module Data.Array.Nested (
rrerank,
rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
rslice, rrev1, rreshape, rflatten, riota,
- rminIndexPrim, rmaxIndexPrim, rdot,
+ rminIndexPrim, rmaxIndexPrim, rdot, rdot1,
rnest, runNest,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift, rlift2,
@@ -31,7 +31,7 @@ module Data.Array.Nested (
srerank,
sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
sslice, srev1, sreshape, sflatten, siota,
- sminIndexPrim, smaxIndexPrim, sdot,
+ sminIndexPrim, smaxIndexPrim, sdot, sdot1,
snest, sunNest,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift, slift2,
@@ -48,7 +48,7 @@ module Data.Array.Nested (
mrerank,
mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
mslice, mrev1, mreshape, mflatten, miota,
- mminIndexPrim, mmaxIndexPrim, mdot,
+ mminIndexPrim, mmaxIndexPrim, mdot, mdot1,
mnest, munNest,
-- ** Lifting orthotope operations to 'Mixed' arrays
mlift, mlift2,
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 69df52a..7dbff83 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -812,10 +812,16 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
ixxFromList (ssxFromShape sh) (numEltMaxIndex arr)
-mdot :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a
-mdot (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) =
+mdot1 :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a
+mdot1 (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) =
numEltDotprod arr1 arr2
+-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
+-- Prefer 'mdot1' if applicable.
+mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
+mdot a b = mdot1 (fromPrimitive (mflatten (toPrimitive a)))
+ (fromPrimitive (mflatten (toPrimitive b)))
+
mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 59c1820..1518791 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -464,7 +464,12 @@ rmaxIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
= ixCvtXR (mmaxIndexPrim arr)
-rdot :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a
+rdot1 :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a
+rdot1 = coerce mdot1
+
+-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
+-- Prefer 'rdot1' if applicable.
+rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
rdot = coerce mdot
rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 1855015..e453e51 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -386,7 +386,12 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
-sdot :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a
+sdot1 :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a
+sdot1 = coerce mdot1
+
+-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
+-- Prefer 'sdot1' if applicable.
+sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
sdot = coerce mdot
stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)