From aafe5f6b5fa772d0e2e9f9b4f91bc3e7cf696840 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 19 Jun 2024 15:57:43 +0200 Subject: Add {m,r,s}dot1Inner --- src/Data/Array/Nested/Internal/Mixed.hs | 25 ++++++++++++----- src/Data/Array/Nested/Internal/Ranked.hs | 9 ++++-- src/Data/Array/Nested/Internal/Shape.hs | 48 ++++++++++++++++++++++++++++++-- src/Data/Array/Nested/Internal/Shaped.hs | 14 ++++++++-- 4 files changed, 80 insertions(+), 16 deletions(-) (limited to 'src/Data/Array/Nested') diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 215313e..50202ba 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -104,7 +104,7 @@ newtype Primitive a = Primitive a -- | Element types that are primitive; arrays of these types are just a newtype -- wrapper over an array. -class Storable a => PrimElt a where +class (Storable a, Elt a) => PrimElt a where fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) @@ -854,15 +854,26 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = ixxFromList (ssxFromShape sh) (numEltMaxIndex arr) -mdot1 :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a -mdot1 (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) = - numEltDotprod arr1 arr2 +mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a +mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sh1 of + _ :$% _ + | sh1 == sh2 + , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> + fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b)) + | otherwise -> error "mdot1Inner: Unequal shapes" + ZSX -> error "unreachable" -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'mdot1' if applicable. +-- Prefer 'mdot1Inner' if applicable. mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a -mdot a b = mdot1 (fromPrimitive (mflatten (toPrimitive a))) - (fromPrimitive (mflatten (toPrimitive b))) +mdot a b = + munScalar $ + mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a))) + (fromPrimitive (mflatten (toPrimitive b))) mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 74b2186..735d1a3 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -483,11 +483,14 @@ rmaxIndexPrim rarr@(Ranked arr) | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) = ixCvtXR (mmaxIndexPrim arr) -rdot1 :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a -rdot1 = coerce mdot1 +rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a +rdot1Inner arr1 arr2 + | SNat <- rrank arr1 + , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat)) + = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2 -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'rdot1' if applicable. +-- Prefer 'rdot1Inner' if applicable. rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a rdot = coerce mdot diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs index ca04840..7077053 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Internal/Shape.hs @@ -87,6 +87,16 @@ listrTail :: ListR (n + 1) i -> ListR n i listrTail (_ ::: sh) = sh listrTail ZR = error "unreachable" +listrInit :: ListR (n + 1) i -> ListR n i +listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh +listrInit (_ ::: ZR) = ZR +listrInit ZR = error "unreachable" + +listrLast :: ListR (n + 1) i -> i +listrLast (_ ::: sh@(_ ::: _)) = listrLast sh +listrLast (n ::: ZR) = n +listrLast ZR = error "unreachable" + listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i listrIndex SZ (x ::: _) = x listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs @@ -166,6 +176,12 @@ ixrHead (IxR list) = listrHead list ixrTail :: IxR (n + 1) i -> IxR n i ixrTail (IxR list) = IxR (listrTail list) +ixrInit :: IxR (n + 1) i -> IxR n i +ixrInit (IxR list) = IxR (listrInit list) + +ixrLast :: IxR (n + 1) i -> i +ixrLast (IxR list) = listrLast list + ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i ixrAppend = coerce (listrAppend @_ @i) @@ -235,6 +251,12 @@ shrHead (ShR list) = listrHead list shrTail :: ShR (n + 1) i -> ShR n i shrTail (ShR list) = ShR (listrTail list) +shrInit :: ShR (n + 1) i -> ShR n i +shrInit (ShR list) = ShR (listrInit list) + +shrLast :: ShR (n + 1) i -> i +shrLast (ShR list) = listrLast list + shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i shrAppend = coerce (listrAppend @_ @i) @@ -310,17 +332,25 @@ listsToList :: ListS sh (Const i) -> [i] listsToList ZS = [] listsToList (Const i ::$ is) = i : listsToList is -listsHead :: ListS (n : sh) i -> i n +listsHead :: ListS (n : sh) f -> f n listsHead (i ::$ _) = i -listsTail :: ListS (n : sh) i -> ListS sh i +listsTail :: ListS (n : sh) f -> ListS sh f listsTail (_ ::$ sh) = sh +listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f +listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh +listsInit (_ ::$ ZS) = ZS + +listsLast :: ListS (n : sh) f -> f (Last (n : sh)) +listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh +listsLast (n ::$ ZS) = n + listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f listsAppend ZS idx' = idx' listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' -listsRank :: ListS sh i -> SNat (Rank sh) +listsRank :: ListS sh f -> SNat (Rank sh) listsRank ZS = SNat listsRank (_ ::$ sh) = snatSucc (listsRank sh) @@ -403,6 +433,12 @@ ixsHead (IxS list) = getConst (listsHead list) ixsTail :: IxS (n : sh) i -> IxS sh i ixsTail (IxS list) = IxS (listsTail list) +ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i +ixsInit (IxS list) = IxS (listsInit list) + +ixsLast :: IxS (n : sh) i -> i +ixsLast (IxS list) = getConst (listsLast list) + ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) @@ -469,6 +505,12 @@ shsHead (ShS list) = listsHead list shsTail :: ShS (n : sh) -> ShS sh shsTail (ShS list) = ShS (listsTail list) +shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) +shsInit (ShS list) = ShS (listsInit list) + +shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS list) = listsLast list + shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') shsAppend = coerce (listsAppend @_ @SNat) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index d013959..995507f 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -418,11 +418,19 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) -sdot1 :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a -sdot1 = coerce mdot1 +sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a +sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sshape sarr1 of + _ :$$ _ + | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) + -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) + _ -> error "unreachable" -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'sdot1' if applicable. +-- Prefer 'sdot1Inner' if applicable. sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a sdot = coerce mdot -- cgit v1.2.3-70-g09d2