aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Shaped.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal/Shaped.hs')
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs14
1 files changed, 11 insertions, 3 deletions
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index d013959..995507f 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -418,11 +418,19 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
-sdot1 :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a
-sdot1 = coerce mdot1
+sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
+ => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
+sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
+ | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
+ , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
+ = case sshape sarr1 of
+ _ :$$ _
+ | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n])
+ -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
+ _ -> error "unreachable"
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'sdot1' if applicable.
+-- Prefer 'sdot1Inner' if applicable.
sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
sdot = coerce mdot