diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-19 15:57:43 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-19 15:57:43 +0200 |
commit | aafe5f6b5fa772d0e2e9f9b4f91bc3e7cf696840 (patch) | |
tree | c0d0d81a9c40f72adf041b165819ab0c7daa44bf /src/Data/Array/Nested/Internal/Shaped.hs | |
parent | 97ab8502b9cd3f7d908160d13c7d85d23c99e203 (diff) |
Add {m,r,s}dot1Inner
Diffstat (limited to 'src/Data/Array/Nested/Internal/Shaped.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index d013959..995507f 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -418,11 +418,19 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) -sdot1 :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a -sdot1 = coerce mdot1 +sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a +sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sshape sarr1 of + _ :$$ _ + | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) + -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) + _ -> error "unreachable" -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'sdot1' if applicable. +-- Prefer 'sdot1Inner' if applicable. sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a sdot = coerce mdot |