aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs118
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs1
-rw-r--r--src/Data/Array/Mixed/Lemmas.hs3
-rw-r--r--src/Data/Array/Mixed/Shape.hs16
-rw-r--r--src/Data/Array/Mixed/Types.hs5
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs25
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs9
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs48
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs14
10 files changed, 172 insertions, 73 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 9f99c3b..fc26633 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -31,6 +31,7 @@ import System.IO.Unsafe
import Data.Array.Mixed.Internal.Arith.Foreign
import Data.Array.Mixed.Internal.Arith.Lists
+import Data.Array.Mixed.Types (fromSNat')
-- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition
@@ -304,36 +305,44 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
. VS.toList
<$> VS.unsafeFreeze outvR
-vectorDotprodOp :: (Num a, Storable a)
- => (b -> a)
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
- -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel
- -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ strided dotprod kernel
- -> RS.Array 1 a -> RS.Array 1 a -> a
-vectorDotprodOp valbackconv ptrconv fred fdot fdotstrided
- (RS.A (RG.A [len1] (OI.T [stride1] offset1 vec1)))
- (RS.A (RG.A [len2] (OI.T [stride2] offset2 vec2)))
- | len1 /= len2 = error $ "vectorDotprodOp: lengths unequal: " ++ show len1 ++ " vs " ++ show len2
- | len1 == 0 = 0 -- if the arrays are empty, just return zero
- | otherwise = case (stride1, stride2) of
- (0, 0) -> -- replicated scalar * replicated scalar
- fromIntegral len1 * (vec1 VS.! offset1) * (vec2 VS.! offset2)
- (0, 1) -> -- replicated scalar * dense
- dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice offset2 len1 vec2)
- (0, -1) -> -- replicated scalar * reversed dense
- dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice (offset2 - (len1 - 1)) len1 vec2)
- (1, 0) -> -- dense * replicated scalar
- dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice offset1 len1 vec1)
- (-1, 0) -> -- reversed dense * replicated scalar
- dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice (offset1 - (len1 - 1)) len1 vec1)
- (1, 1) -> -- dense * dense
- dotVectorVector len1 valbackconv ptrconv fdot (VS.slice offset1 len1 vec1) (VS.slice offset2 len1 vec2)
- (-1, -1) -> -- reversed dense * reversed dense
- dotVectorVector len1 valbackconv ptrconv fdot (VS.slice (offset1 - (len1 - 1)) len1 vec1) (VS.slice (offset2 - (len1 - 1)) len1 vec2)
- (_, _) -> -- fallback case
- dotVectorVectorStrided len1 valbackconv ptrconv fdotstrided offset1 stride1 vec1 offset2 stride2 vec2
-vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?"
+vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a) -- ^ elementwise multiplication
+ -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
+ -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a
+vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
+ arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
+ arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
+ | null sh1 || null sh2 = error "unreachable"
+ | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+ | last sh1 <= 0 = RS.stretch (init sh1) (RS.fromList (1 <$ init sh1) [0])
+ | any (<= 0) (init sh1) = RS.A (RG.A (init sh1) (OI.T (0 <$ init strides1) 0 VS.empty))
+ -- now the input arrays are nonempty
+ | last sh1 == 1 = fmul sn (RS.reshape (init sh1) arr1) (RS.reshape (init sh1) arr2)
+ | last strides1 == 0 =
+ fmul sn
+ (RS.A (RG.A (init sh1) (OI.T (init strides1) offset1 vec1)))
+ (vectorRedInnerOp sn valconv ptrconv fscale fred arr2)
+ | last strides2 == 0 =
+ fmul sn
+ (vectorRedInnerOp sn valconv ptrconv fscale fred arr1)
+ (RS.A (RG.A (init sh2) (OI.T (init strides2) offset2 vec2)))
+ -- now there is useful dotprod work along the inner dimension
+ | otherwise = unsafePerformIO $ do
+ let inrank = fromSNat' sn + 1
+ outv <- VSM.unsafeNew (product (init sh1))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 ->
+ VS.unsafeWith vec1 $ \pvec1 ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 ->
+ VS.unsafeWith vec2 $ \pvec2 ->
+ fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) pstrides1 (ptrconv pvec1) pstrides2 (ptrconv pvec2)
+ RS.fromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv
{-# NOINLINE dotScalarVector #-}
dotScalarVector :: forall a b. (Num a, Storable a)
@@ -461,13 +470,14 @@ $(fmap concat . forM typesList $ \arithtype ->
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- name = mkName ("dotprodVector" ++ nameBase (atType arithtype))
- c_op = varE (mkName ("c_dotprod_" ++ atCName arithtype))
- c_op_strided = varE (mkName ("c_dotprod_" ++ atCName arithtype ++ "_strided"))
+ name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype))
+ mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
sequence [SigD name <$>
- [t| RS.Array 1 $ttyp -> RS.Array 1 $ttyp -> $ttyp |]
- ,do body <- [| vectorDotprodOp id id $c_red_op $c_op $c_op_strided |]
+ [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |]
return $ FunD name [Clause [] (NormalB body) []]])
-- This branch is ostensibly a runtime branch, but will (hopefully) be
@@ -533,19 +543,19 @@ intWidBranchExtr fextr32 fextr64
| finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
| otherwise = error "Unsupported Int width"
-intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i)
+intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i)
=> -- int32
- (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
- -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32) -- ^ dotprod kernel
- -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32) -- ^ strided dotprod kernel
+ (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ dotprod kernel
-- int64
+ -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64) -- ^ dotprod kernel
- -> (Int64 -> Int64 -> Int64 -> Ptr Int64 -> Int64 -> Int64 -> Ptr Int64 -> IO Int64) -- ^ strided dotprod kernel
- -> (RS.Array 1 i -> RS.Array 1 i -> i)
-intWidBranchDotprod fred32 fdot32 fdot32strided fred64 fdot64 fdot64strided
- | finiteBitSize (undefined :: i) == 32 = vectorDotprodOp @i @Int32 fromIntegral castPtr fred32 fdot32 fdot32strided
- | finiteBitSize (undefined :: i) == 64 = vectorDotprodOp @i @Int64 fromIntegral castPtr fred64 fdot64 fdot64strided
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ dotprod kernel
+ -> (SNat n -> RS.Array (n + 1) i -> RS.Array (n + 1) i -> RS.Array n i)
+intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32
+ | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64
| otherwise = error "Unsupported Int width"
class NumElt a where
@@ -561,7 +571,7 @@ class NumElt a where
numEltProductFull :: SNat n -> RS.Array n a -> a
numEltMinIndex :: RS.Array n a -> [Int]
numEltMaxIndex :: RS.Array n a -> [Int]
- numEltDotprod :: RS.Array 1 a -> RS.Array 1 a -> a
+ numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a
instance NumElt Int32 where
numEltAdd = addVectorInt32
@@ -576,7 +586,7 @@ instance NumElt Int32 where
numEltProductFull = productFullVectorInt32
numEltMinIndex = minindexVectorInt32
numEltMaxIndex = maxindexVectorInt32
- numEltDotprod = dotprodVectorInt32
+ numEltDotprodInner = dotprodinnerVectorInt32
instance NumElt Int64 where
numEltAdd = addVectorInt64
@@ -591,7 +601,7 @@ instance NumElt Int64 where
numEltProductFull = productFullVectorInt64
numEltMinIndex = minindexVectorInt64
numEltMaxIndex = maxindexVectorInt64
- numEltDotprod = dotprodVectorInt64
+ numEltDotprodInner = dotprodinnerVectorInt64
instance NumElt Float where
numEltAdd = addVectorFloat
@@ -606,7 +616,7 @@ instance NumElt Float where
numEltProductFull = productFullVectorFloat
numEltMinIndex = minindexVectorFloat
numEltMaxIndex = maxindexVectorFloat
- numEltDotprod = dotprodVectorFloat
+ numEltDotprodInner = dotprodinnerVectorFloat
instance NumElt Double where
numEltAdd = addVectorDouble
@@ -621,7 +631,7 @@ instance NumElt Double where
numEltProductFull = productFullVectorDouble
numEltMinIndex = minindexVectorDouble
numEltMaxIndex = maxindexVectorDouble
- numEltDotprod = dotprodVectorDouble
+ numEltDotprodInner = dotprodinnerVectorDouble
instance NumElt Int where
numEltAdd = intWidBranch2 @Int (+)
@@ -646,8 +656,8 @@ instance NumElt Int where
numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
- numEltDotprod = intWidBranchDotprod @Int (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprod_i32 c_dotprod_i32_strided
- (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprod_i64 c_dotprod_i64_strided
+ numEltDotprodInner = intWidBranchDotprod @Int (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
instance NumElt CInt where
numEltAdd = intWidBranch2 @CInt (+)
@@ -672,8 +682,8 @@ instance NumElt CInt where
numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
- numEltDotprod = intWidBranchDotprod @CInt (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprod_i32 c_dotprod_i32_strided
- (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprod_i64 c_dotprod_i64_strided
+ numEltDotprodInner = intWidBranchDotprod @CInt (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
class FloatElt a where
floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index c1c0070..ade7ce1 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -22,6 +22,7 @@ $(do
,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
,("dotprod_" ++ tyn, [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |])
,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |])
+ ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
]
let importsFloat ttyp tyn =
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs
index 633c9c2..ec7e7bd 100644
--- a/src/Data/Array/Mixed/Lemmas.hs
+++ b/src/Data/Array/Mixed/Lemmas.hs
@@ -108,6 +108,9 @@ lemTakeLenApp _ _ _ = unsafeCoerceRefl
lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l
lemInitApp _ _ = unsafeCoerceRefl
+lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x
+lemLastApp _ _ = unsafeCoerceRefl
+
-- ** KnownNat
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index a15e0a2..1285aa1 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -121,6 +121,10 @@ listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh
listxInit (_ ::% ZX) = ZX
+listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh))
+listxLast (_ ::% sh@(_ ::% _)) = listxLast sh
+listxLast (x ::% ZX) = x
+
-- * Mixed indices
@@ -179,6 +183,9 @@ ixxDrop = coerce (listxDrop @(Const i) @(Const i))
ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
ixxInit = coerce (listxInit @(Const i))
+ixxLast :: forall n sh i. IxX (n : sh) i -> i
+ixxLast = coerce (listxLast @(Const i))
+
ixxFromLinear :: IShX sh -> Int -> IIxX sh
ixxFromLinear = \sh i -> case go sh i of
(idx, 0) -> idx
@@ -330,6 +337,9 @@ shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
shxInit = coerce (listxInit @(SMayNat i SNat))
+shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
+shxLast = coerce (listxLast @(SMayNat i SNat))
+
shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
shxTakeSSX _ = flip go
where
@@ -404,6 +414,12 @@ ssxTail (_ :!% ssh) = ssh
ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
+ssxInit = coerce (listxInit @(SMayNat () SNat))
+
+ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
+ssxLast = coerce (listxLast @(SMayNat () SNat))
+
-- | This may fail if @sh@ has @Nothing@s in it.
ssxToShX' :: StaticShX sh -> Maybe (IShX sh)
ssxToShX' ZKX = Just ZSX
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs
index 22d06e5..8e90a88 100644
--- a/src/Data/Array/Mixed/Types.hs
+++ b/src/Data/Array/Mixed/Types.hs
@@ -29,6 +29,7 @@ module Data.Array.Mixed.Types (
MapJust,
Tail,
Init,
+ Last,
-- * Unsafe
unsafeCoerceRefl,
@@ -105,6 +106,10 @@ type family Init l where
Init (x : y : xs) = x : Init (y : xs)
Init '[x] = '[]
+type family Last l where
+ Last (x : y : xs) = Last (y : xs)
+ Last '[x] = x
+
-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to
-- only typecheck for actual type equalities. One cannot, e.g. accidentally
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 53417bd..f37619f 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -12,7 +12,7 @@ module Data.Array.Nested (
rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
rfromListLinear, rtoListLinear,
rslice, rrev1, rreshape, rflatten, riota,
- rminIndexPrim, rmaxIndexPrim, rdot, rdot1,
+ rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot,
rnest, runNest,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift, rlift2,
@@ -33,7 +33,7 @@ module Data.Array.Nested (
sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
sfromListLinear, stoListLinear,
sslice, srev1, sreshape, sflatten, siota,
- sminIndexPrim, smaxIndexPrim, sdot, sdot1,
+ sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot,
snest, sunNest,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift, slift2,
@@ -54,7 +54,7 @@ module Data.Array.Nested (
mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
mfromListLinear, mtoListLinear,
mslice, mrev1, mreshape, mflatten, miota,
- mminIndexPrim, mmaxIndexPrim, mdot, mdot1,
+ mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot,
mnest, munNest,
-- ** Lifting orthotope operations to 'Mixed' arrays
mlift, mlift2,
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 215313e..50202ba 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -104,7 +104,7 @@ newtype Primitive a = Primitive a
-- | Element types that are primitive; arrays of these types are just a newtype
-- wrapper over an array.
-class Storable a => PrimElt a where
+class (Storable a, Elt a) => PrimElt a where
fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
@@ -854,15 +854,26 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
ixxFromList (ssxFromShape sh) (numEltMaxIndex arr)
-mdot1 :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a
-mdot1 (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) =
- numEltDotprod arr1 arr2
+mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
+ => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
+mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b))
+ | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
+ , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
+ = case sh1 of
+ _ :$% _
+ | sh1 == sh2
+ , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) ->
+ fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b))
+ | otherwise -> error "mdot1Inner: Unequal shapes"
+ ZSX -> error "unreachable"
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'mdot1' if applicable.
+-- Prefer 'mdot1Inner' if applicable.
mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
-mdot a b = mdot1 (fromPrimitive (mflatten (toPrimitive a)))
- (fromPrimitive (mflatten (toPrimitive b)))
+mdot a b =
+ munScalar $
+ mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a)))
+ (fromPrimitive (mflatten (toPrimitive b)))
mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 74b2186..735d1a3 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -483,11 +483,14 @@ rmaxIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
= ixCvtXR (mmaxIndexPrim arr)
-rdot1 :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a
-rdot1 = coerce mdot1
+rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
+rdot1Inner arr1 arr2
+ | SNat <- rrank arr1
+ , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat))
+ = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'rdot1' if applicable.
+-- Prefer 'rdot1Inner' if applicable.
rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
rdot = coerce mdot
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index ca04840..7077053 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -87,6 +87,16 @@ listrTail :: ListR (n + 1) i -> ListR n i
listrTail (_ ::: sh) = sh
listrTail ZR = error "unreachable"
+listrInit :: ListR (n + 1) i -> ListR n i
+listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
+listrInit (_ ::: ZR) = ZR
+listrInit ZR = error "unreachable"
+
+listrLast :: ListR (n + 1) i -> i
+listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
+listrLast (n ::: ZR) = n
+listrLast ZR = error "unreachable"
+
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
@@ -166,6 +176,12 @@ ixrHead (IxR list) = listrHead list
ixrTail :: IxR (n + 1) i -> IxR n i
ixrTail (IxR list) = IxR (listrTail list)
+ixrInit :: IxR (n + 1) i -> IxR n i
+ixrInit (IxR list) = IxR (listrInit list)
+
+ixrLast :: IxR (n + 1) i -> i
+ixrLast (IxR list) = listrLast list
+
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend = coerce (listrAppend @_ @i)
@@ -235,6 +251,12 @@ shrHead (ShR list) = listrHead list
shrTail :: ShR (n + 1) i -> ShR n i
shrTail (ShR list) = ShR (listrTail list)
+shrInit :: ShR (n + 1) i -> ShR n i
+shrInit (ShR list) = ShR (listrInit list)
+
+shrLast :: ShR (n + 1) i -> i
+shrLast (ShR list) = listrLast list
+
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend = coerce (listrAppend @_ @i)
@@ -310,17 +332,25 @@ listsToList :: ListS sh (Const i) -> [i]
listsToList ZS = []
listsToList (Const i ::$ is) = i : listsToList is
-listsHead :: ListS (n : sh) i -> i n
+listsHead :: ListS (n : sh) f -> f n
listsHead (i ::$ _) = i
-listsTail :: ListS (n : sh) i -> ListS sh i
+listsTail :: ListS (n : sh) f -> ListS sh f
listsTail (_ ::$ sh) = sh
+listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
+listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
+listsInit (_ ::$ ZS) = ZS
+
+listsLast :: ListS (n : sh) f -> f (Last (n : sh))
+listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
+listsLast (n ::$ ZS) = n
+
listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-listsRank :: ListS sh i -> SNat (Rank sh)
+listsRank :: ListS sh f -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ sh) = snatSucc (listsRank sh)
@@ -403,6 +433,12 @@ ixsHead (IxS list) = getConst (listsHead list)
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (IxS list) = IxS (listsTail list)
+ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
+ixsInit (IxS list) = IxS (listsInit list)
+
+ixsLast :: IxS (n : sh) i -> i
+ixsLast (IxS list) = getConst (listsLast list)
+
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))
@@ -469,6 +505,12 @@ shsHead (ShS list) = listsHead list
shsTail :: ShS (n : sh) -> ShS sh
shsTail (ShS list) = ShS (listsTail list)
+shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
+shsInit (ShS list) = ShS (listsInit list)
+
+shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
+shsLast (ShS list) = listsLast list
+
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
shsAppend = coerce (listsAppend @_ @SNat)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index d013959..995507f 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -418,11 +418,19 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
-sdot1 :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a
-sdot1 = coerce mdot1
+sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
+ => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
+sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
+ | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
+ , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
+ = case sshape sarr1 of
+ _ :$$ _
+ | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n])
+ -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
+ _ -> error "unreachable"
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'sdot1' if applicable.
+-- Prefer 'sdot1Inner' if applicable.
sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
sdot = coerce mdot