From aafe5f6b5fa772d0e2e9f9b4f91bc3e7cf696840 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 19 Jun 2024 15:57:43 +0200 Subject: Add {m,r,s}dot1Inner --- src/Data/Array/Nested/Internal/Shaped.hs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'src/Data/Array/Nested/Internal/Shaped.hs') diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index d013959..995507f 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -418,11 +418,19 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) -sdot1 :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a -sdot1 = coerce mdot1 +sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a +sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sshape sarr1 of + _ :$$ _ + | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) + -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) + _ -> error "unreachable" -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'sdot1' if applicable. +-- Prefer 'sdot1Inner' if applicable. sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a sdot = coerce mdot -- cgit v1.2.3-70-g09d2