diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-19 15:57:43 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-19 15:57:43 +0200 |
commit | aafe5f6b5fa772d0e2e9f9b4f91bc3e7cf696840 (patch) | |
tree | c0d0d81a9c40f72adf041b165819ab0c7daa44bf /src/Data/Array/Nested/Internal/Mixed.hs | |
parent | 97ab8502b9cd3f7d908160d13c7d85d23c99e203 (diff) |
Add {m,r,s}dot1Inner
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 25 |
1 files changed, 18 insertions, 7 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 215313e..50202ba 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -104,7 +104,7 @@ newtype Primitive a = Primitive a -- | Element types that are primitive; arrays of these types are just a newtype -- wrapper over an array. -class Storable a => PrimElt a where +class (Storable a, Elt a) => PrimElt a where fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) @@ -854,15 +854,26 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = ixxFromList (ssxFromShape sh) (numEltMaxIndex arr) -mdot1 :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a -mdot1 (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) = - numEltDotprod arr1 arr2 +mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a +mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sh1 of + _ :$% _ + | sh1 == sh2 + , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> + fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b)) + | otherwise -> error "mdot1Inner: Unequal shapes" + ZSX -> error "unreachable" -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'mdot1' if applicable. +-- Prefer 'mdot1Inner' if applicable. mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a -mdot a b = mdot1 (fromPrimitive (mflatten (toPrimitive a))) - (fromPrimitive (mflatten (toPrimitive b))) +mdot a b = + munScalar $ + mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a))) + (fromPrimitive (mflatten (toPrimitive b))) mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) |