From 205a20fd581bb7c5728fd457a15e4f78fbee9e75 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Jun 2024 10:02:59 +0200 Subject: Dot product --- src/Data/Array/Nested/Internal/Mixed.hs | 9 +++++++++ src/Data/Array/Nested/Internal/Ranked.hs | 3 +++ src/Data/Array/Nested/Internal/Shaped.hs | 3 +++ 3 files changed, 15 insertions(+) (limited to 'src/Data/Array/Nested/Internal') diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index a0de08b..2c99487 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -49,6 +49,11 @@ import Data.Array.Mixed.Types import Data.Array.Mixed.Permutation import Data.Array.Mixed.Lemmas +-- TODO: +-- dotprod, sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a +-- After benchmarking: matmul and matvec + + -- Invariant in the API -- ==================== @@ -798,6 +803,10 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = ixxFromList (ssxFromShape sh) (numEltMaxIndex arr) +mdot :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a +mdot (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) = + numEltDotprod arr1 arr2 + mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 589f0c1..c67e892 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -461,6 +461,9 @@ rmaxIndexPrim rarr@(Ranked arr) | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) = ixCvtXR (mmaxIndexPrim arr) +rdot :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a +rdot = coerce mdot + rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index ca3fd45..9320495 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -381,6 +381,9 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) +sdot :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a +sdot = coerce mdot + stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) -- cgit v1.2.3-70-g09d2