aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-06-19 15:57:43 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-06-19 15:57:43 +0200
commitaafe5f6b5fa772d0e2e9f9b4f91bc3e7cf696840 (patch)
treec0d0d81a9c40f72adf041b165819ab0c7daa44bf /src/Data/Array/Nested/Internal/Mixed.hs
parent97ab8502b9cd3f7d908160d13c7d85d23c99e203 (diff)
Add {m,r,s}dot1Inner
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs25
1 files changed, 18 insertions, 7 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 215313e..50202ba 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -104,7 +104,7 @@ newtype Primitive a = Primitive a
-- | Element types that are primitive; arrays of these types are just a newtype
-- wrapper over an array.
-class Storable a => PrimElt a where
+class (Storable a, Elt a) => PrimElt a where
fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
@@ -854,15 +854,26 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
ixxFromList (ssxFromShape sh) (numEltMaxIndex arr)
-mdot1 :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a
-mdot1 (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) =
- numEltDotprod arr1 arr2
+mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
+ => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
+mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b))
+ | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
+ , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
+ = case sh1 of
+ _ :$% _
+ | sh1 == sh2
+ , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) ->
+ fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b))
+ | otherwise -> error "mdot1Inner: Unequal shapes"
+ ZSX -> error "unreachable"
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'mdot1' if applicable.
+-- Prefer 'mdot1Inner' if applicable.
mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
-mdot a b = mdot1 (fromPrimitive (mflatten (toPrimitive a)))
- (fromPrimitive (mflatten (toPrimitive b)))
+mdot a b =
+ munScalar $
+ mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a)))
+ (fromPrimitive (mflatten (toPrimitive b)))
mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)