aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-10 10:02:59 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-06-10 16:17:09 +0200
commit205a20fd581bb7c5728fd457a15e4f78fbee9e75 (patch)
treef6669ea87b56dde0f6168c109b3d7d7fcbf06136 /src/Data/Array/Nested
parentc211316a4ab43cf34d6567c6919a3922d5840ae0 (diff)
Dot product
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs9
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs3
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs3
3 files changed, 15 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index a0de08b..2c99487 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -49,6 +49,11 @@ import Data.Array.Mixed.Types
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Lemmas
+-- TODO:
+-- dotprod, sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
+-- After benchmarking: matmul and matvec
+
+
-- Invariant in the API
-- ====================
@@ -798,6 +803,10 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
ixxFromList (ssxFromShape sh) (numEltMaxIndex arr)
+mdot :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a
+mdot (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) =
+ numEltDotprod arr1 arr2
+
mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 589f0c1..c67e892 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -461,6 +461,9 @@ rmaxIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
= ixCvtXR (mmaxIndexPrim arr)
+rdot :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a
+rdot = coerce mdot
+
rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index ca3fd45..9320495 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -381,6 +381,9 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
+sdot :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a
+sdot = coerce mdot
+
stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr)