diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 23:31:59 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 23:32:08 +0200 | 
| commit | 42b8c69a978b54001aeae62c8c37ce80500d6428 (patch) | |
| tree | 2318a443fe2cd659cc432dbc6d7c53e8206ac4e8 | |
| parent | 5f6a64660b16d8f188caca5216e55debf4264611 (diff) | |
Add (temporary version of) more general mdot
| -rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 10 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 7 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 7 | 
4 files changed, 23 insertions, 7 deletions
| diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index b5c0772..f75a71c 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -11,7 +11,7 @@ module Data.Array.Nested (    rrerank,    rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,    rslice, rrev1, rreshape, rflatten, riota, -  rminIndexPrim, rmaxIndexPrim, rdot, +  rminIndexPrim, rmaxIndexPrim, rdot, rdot1,    rnest, runNest,    -- ** Lifting orthotope operations to 'Ranked' arrays    rlift, rlift2, @@ -31,7 +31,7 @@ module Data.Array.Nested (    srerank,    sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,    sslice, srev1, sreshape, sflatten, siota, -  sminIndexPrim, smaxIndexPrim, sdot, +  sminIndexPrim, smaxIndexPrim, sdot, sdot1,    snest, sunNest,    -- ** Lifting orthotope operations to 'Shaped' arrays    slift, slift2, @@ -48,7 +48,7 @@ module Data.Array.Nested (    mrerank,    mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,    mslice, mrev1, mreshape, mflatten, miota, -  mminIndexPrim, mmaxIndexPrim, mdot, +  mminIndexPrim, mmaxIndexPrim, mdot, mdot1,    mnest, munNest,    -- ** Lifting orthotope operations to 'Mixed' arrays    mlift, mlift2, diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 69df52a..7dbff83 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -812,10 +812,16 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh  mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =    ixxFromList (ssxFromShape sh) (numEltMaxIndex arr) -mdot :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a -mdot (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) = +mdot1 :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a +mdot1 (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) =    numEltDotprod arr1 arr2 +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'mdot1' if applicable. +mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a +mdot a b = mdot1 (fromPrimitive (mflatten (toPrimitive a))) +                 (fromPrimitive (mflatten (toPrimitive b))) +  mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)  mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 59c1820..1518791 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -464,7 +464,12 @@ rmaxIndexPrim rarr@(Ranked arr)    | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))    = ixCvtXR (mmaxIndexPrim arr) -rdot :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a +rdot1 :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a +rdot1 = coerce mdot1 + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'rdot1' if applicable. +rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a  rdot = coerce mdot  rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 1855015..e453e51 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -386,7 +386,12 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde  smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh  smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) -sdot :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a +sdot1 :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a +sdot1 = coerce mdot1 + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'sdot1' if applicable. +sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a  sdot = coerce mdot  stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) | 
