aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-06-19 15:57:43 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-06-19 15:57:43 +0200
commitaafe5f6b5fa772d0e2e9f9b4f91bc3e7cf696840 (patch)
treec0d0d81a9c40f72adf041b165819ab0c7daa44bf /src/Data/Array/Nested
parent97ab8502b9cd3f7d908160d13c7d85d23c99e203 (diff)
Add {m,r,s}dot1Inner
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs25
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs9
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs48
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs14
4 files changed, 80 insertions, 16 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 215313e..50202ba 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -104,7 +104,7 @@ newtype Primitive a = Primitive a
-- | Element types that are primitive; arrays of these types are just a newtype
-- wrapper over an array.
-class Storable a => PrimElt a where
+class (Storable a, Elt a) => PrimElt a where
fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
@@ -854,15 +854,26 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
ixxFromList (ssxFromShape sh) (numEltMaxIndex arr)
-mdot1 :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a
-mdot1 (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) =
- numEltDotprod arr1 arr2
+mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
+ => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
+mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b))
+ | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
+ , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
+ = case sh1 of
+ _ :$% _
+ | sh1 == sh2
+ , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) ->
+ fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b))
+ | otherwise -> error "mdot1Inner: Unequal shapes"
+ ZSX -> error "unreachable"
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'mdot1' if applicable.
+-- Prefer 'mdot1Inner' if applicable.
mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
-mdot a b = mdot1 (fromPrimitive (mflatten (toPrimitive a)))
- (fromPrimitive (mflatten (toPrimitive b)))
+mdot a b =
+ munScalar $
+ mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a)))
+ (fromPrimitive (mflatten (toPrimitive b)))
mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 74b2186..735d1a3 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -483,11 +483,14 @@ rmaxIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
= ixCvtXR (mmaxIndexPrim arr)
-rdot1 :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a
-rdot1 = coerce mdot1
+rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
+rdot1Inner arr1 arr2
+ | SNat <- rrank arr1
+ , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat))
+ = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'rdot1' if applicable.
+-- Prefer 'rdot1Inner' if applicable.
rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
rdot = coerce mdot
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index ca04840..7077053 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -87,6 +87,16 @@ listrTail :: ListR (n + 1) i -> ListR n i
listrTail (_ ::: sh) = sh
listrTail ZR = error "unreachable"
+listrInit :: ListR (n + 1) i -> ListR n i
+listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
+listrInit (_ ::: ZR) = ZR
+listrInit ZR = error "unreachable"
+
+listrLast :: ListR (n + 1) i -> i
+listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
+listrLast (n ::: ZR) = n
+listrLast ZR = error "unreachable"
+
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
@@ -166,6 +176,12 @@ ixrHead (IxR list) = listrHead list
ixrTail :: IxR (n + 1) i -> IxR n i
ixrTail (IxR list) = IxR (listrTail list)
+ixrInit :: IxR (n + 1) i -> IxR n i
+ixrInit (IxR list) = IxR (listrInit list)
+
+ixrLast :: IxR (n + 1) i -> i
+ixrLast (IxR list) = listrLast list
+
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend = coerce (listrAppend @_ @i)
@@ -235,6 +251,12 @@ shrHead (ShR list) = listrHead list
shrTail :: ShR (n + 1) i -> ShR n i
shrTail (ShR list) = ShR (listrTail list)
+shrInit :: ShR (n + 1) i -> ShR n i
+shrInit (ShR list) = ShR (listrInit list)
+
+shrLast :: ShR (n + 1) i -> i
+shrLast (ShR list) = listrLast list
+
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend = coerce (listrAppend @_ @i)
@@ -310,17 +332,25 @@ listsToList :: ListS sh (Const i) -> [i]
listsToList ZS = []
listsToList (Const i ::$ is) = i : listsToList is
-listsHead :: ListS (n : sh) i -> i n
+listsHead :: ListS (n : sh) f -> f n
listsHead (i ::$ _) = i
-listsTail :: ListS (n : sh) i -> ListS sh i
+listsTail :: ListS (n : sh) f -> ListS sh f
listsTail (_ ::$ sh) = sh
+listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
+listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
+listsInit (_ ::$ ZS) = ZS
+
+listsLast :: ListS (n : sh) f -> f (Last (n : sh))
+listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
+listsLast (n ::$ ZS) = n
+
listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-listsRank :: ListS sh i -> SNat (Rank sh)
+listsRank :: ListS sh f -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ sh) = snatSucc (listsRank sh)
@@ -403,6 +433,12 @@ ixsHead (IxS list) = getConst (listsHead list)
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (IxS list) = IxS (listsTail list)
+ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
+ixsInit (IxS list) = IxS (listsInit list)
+
+ixsLast :: IxS (n : sh) i -> i
+ixsLast (IxS list) = getConst (listsLast list)
+
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))
@@ -469,6 +505,12 @@ shsHead (ShS list) = listsHead list
shsTail :: ShS (n : sh) -> ShS sh
shsTail (ShS list) = ShS (listsTail list)
+shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
+shsInit (ShS list) = ShS (listsInit list)
+
+shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
+shsLast (ShS list) = listsLast list
+
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
shsAppend = coerce (listsAppend @_ @SNat)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index d013959..995507f 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -418,11 +418,19 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
-sdot1 :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a
-sdot1 = coerce mdot1
+sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
+ => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
+sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
+ | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
+ , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
+ = case sshape sarr1 of
+ _ :$$ _
+ | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n])
+ -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
+ _ -> error "unreachable"
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'sdot1' if applicable.
+-- Prefer 'sdot1Inner' if applicable.
sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
sdot = coerce mdot