diff options
| author | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-19 15:57:43 +0200 | 
|---|---|---|
| committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-19 15:57:43 +0200 | 
| commit | aafe5f6b5fa772d0e2e9f9b4f91bc3e7cf696840 (patch) | |
| tree | c0d0d81a9c40f72adf041b165819ab0c7daa44bf /src/Data | |
| parent | 97ab8502b9cd3f7d908160d13c7d85d23c99e203 (diff) | |
Add {m,r,s}dot1Inner
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 118 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 1 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Lemmas.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 16 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Types.hs | 5 | ||||
| -rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 25 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shape.hs | 48 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 14 | 
10 files changed, 172 insertions, 73 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 9f99c3b..fc26633 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -31,6 +31,7 @@ import System.IO.Unsafe  import Data.Array.Mixed.Internal.Arith.Foreign  import Data.Array.Mixed.Internal.Arith.Lists +import Data.Array.Mixed.Types (fromSNat')  -- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition @@ -304,36 +305,44 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))               . VS.toList               <$> VS.unsafeFreeze outvR -vectorDotprodOp :: (Num a, Storable a) -                => (b -> a) -                -> (Ptr a -> Ptr b) -                -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel -                -> (Int64 -> Ptr b -> Ptr b -> IO b)  -- ^ dotprod kernel -                -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b)  -- ^ strided dotprod kernel -                -> RS.Array 1 a -> RS.Array 1 a -> a -vectorDotprodOp valbackconv ptrconv fred fdot fdotstrided -    (RS.A (RG.A [len1] (OI.T [stride1] offset1 vec1))) -    (RS.A (RG.A [len2] (OI.T [stride2] offset2 vec2))) -  | len1 /= len2 = error $ "vectorDotprodOp: lengths unequal: " ++ show len1 ++ " vs " ++ show len2 -  | len1 == 0 = 0  -- if the arrays are empty, just return zero -  | otherwise = case (stride1, stride2) of -      (0, 0) ->  -- replicated scalar * replicated scalar -        fromIntegral len1 * (vec1 VS.! offset1) * (vec2 VS.! offset2) -      (0, 1) ->  -- replicated scalar * dense -        dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice offset2 len1 vec2) -      (0, -1) ->  -- replicated scalar * reversed dense -        dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice (offset2 - (len1 - 1)) len1 vec2) -      (1, 0) ->  -- dense * replicated scalar -        dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice offset1 len1 vec1) -      (-1, 0) ->  -- reversed dense * replicated scalar -        dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice (offset1 - (len1 - 1)) len1 vec1) -      (1, 1) ->  -- dense * dense -        dotVectorVector len1 valbackconv ptrconv fdot (VS.slice offset1 len1 vec1) (VS.slice offset2 len1 vec2) -      (-1, -1) ->  -- reversed dense * reversed dense -        dotVectorVector len1 valbackconv ptrconv fdot (VS.slice (offset1 - (len1 - 1)) len1 vec1) (VS.slice (offset2 - (len1 - 1)) len1 vec2) -      (_, _) ->  -- fallback case -        dotVectorVectorStrided len1 valbackconv ptrconv fdotstrided offset1 stride1 vec1 offset2 stride2 vec2 -vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?" +vectorDotprodInnerOp :: forall a b n. (Num a, Storable a) +                     => SNat n +                     -> (a -> b) +                     -> (Ptr a -> Ptr b) +                     -> (SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a)  -- ^ elementwise multiplication +                     -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant +                     -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel +                     -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ dotprod kernel +                     -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a +vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner +    arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) +    arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) +  | null sh1 || null sh2 = error "unreachable" +  | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 +  | last sh1 <= 0 = RS.stretch (init sh1) (RS.fromList (1 <$ init sh1) [0]) +  | any (<= 0) (init sh1) = RS.A (RG.A (init sh1) (OI.T (0 <$ init strides1) 0 VS.empty)) +  -- now the input arrays are nonempty +  | last sh1 == 1 = fmul sn (RS.reshape (init sh1) arr1) (RS.reshape (init sh1) arr2) +  | last strides1 == 0 = +      fmul sn +        (RS.A (RG.A (init sh1) (OI.T (init strides1) offset1 vec1))) +        (vectorRedInnerOp sn valconv ptrconv fscale fred arr2) +  | last strides2 == 0 = +      fmul sn +        (vectorRedInnerOp sn valconv ptrconv fscale fred arr1) +        (RS.A (RG.A (init sh2) (OI.T (init strides2) offset2 vec2))) +  -- now there is useful dotprod work along the inner dimension +  | otherwise = unsafePerformIO $ do +      let inrank = fromSNat' sn + 1 +      outv <- VSM.unsafeNew (product (init sh1)) +      VSM.unsafeWith outv $ \poutv -> +        VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh -> +        VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 -> +        VS.unsafeWith vec1 $ \pvec1 -> +        VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 -> +        VS.unsafeWith vec2 $ \pvec2 -> +          fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) pstrides1 (ptrconv pvec1) pstrides2 (ptrconv pvec2) +      RS.fromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv  {-# NOINLINE dotScalarVector #-}  dotScalarVector :: forall a b. (Num a, Storable a) @@ -461,13 +470,14 @@ $(fmap concat . forM typesList $ \arithtype ->  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype) -        name = mkName ("dotprodVector" ++ nameBase (atType arithtype)) -        c_op = varE (mkName ("c_dotprod_" ++ atCName arithtype)) -        c_op_strided = varE (mkName ("c_dotprod_" ++ atCName arithtype ++ "_strided")) +        name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype)) +        c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype)) +        mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype))) +        c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))          c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))      sequence [SigD name <$> -                   [t| RS.Array 1 $ttyp -> RS.Array 1 $ttyp -> $ttyp |] -             ,do body <- [| vectorDotprodOp id id $c_red_op $c_op $c_op_strided |] +                   [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] +             ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |]                   return $ FunD name [Clause [] (NormalB body) []]])  -- This branch is ostensibly a runtime branch, but will (hopefully) be @@ -533,19 +543,19 @@ intWidBranchExtr fextr32 fextr64    | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64    | otherwise = error "Unsupported Int width" -intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i) +intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i)                      => -- int32 -                       (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ reduction kernel -                    -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32)  -- ^ dotprod kernel -                    -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32)  -- ^ strided dotprod kernel +                       (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant +                    -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ reduction kernel +                    -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ dotprod kernel                         -- int64 +                    -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant                      -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel -                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64)  -- ^ dotprod kernel -                    -> (Int64 -> Int64 -> Int64 -> Ptr Int64 -> Int64 -> Int64 -> Ptr Int64 -> IO Int64)  -- ^ strided dotprod kernel -                    -> (RS.Array 1 i -> RS.Array 1 i -> i) -intWidBranchDotprod fred32 fdot32 fdot32strided fred64 fdot64 fdot64strided -  | finiteBitSize (undefined :: i) == 32 = vectorDotprodOp @i @Int32 fromIntegral castPtr fred32 fdot32 fdot32strided -  | finiteBitSize (undefined :: i) == 64 = vectorDotprodOp @i @Int64 fromIntegral castPtr fred64 fdot64 fdot64strided +                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ dotprod kernel +                    -> (SNat n -> RS.Array (n + 1) i -> RS.Array (n + 1) i -> RS.Array n i) +intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn +  | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32 +  | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64    | otherwise = error "Unsupported Int width"  class NumElt a where @@ -561,7 +571,7 @@ class NumElt a where    numEltProductFull :: SNat n -> RS.Array n a -> a    numEltMinIndex :: RS.Array n a -> [Int]    numEltMaxIndex :: RS.Array n a -> [Int] -  numEltDotprod :: RS.Array 1 a -> RS.Array 1 a -> a +  numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a  instance NumElt Int32 where    numEltAdd = addVectorInt32 @@ -576,7 +586,7 @@ instance NumElt Int32 where    numEltProductFull = productFullVectorInt32    numEltMinIndex = minindexVectorInt32    numEltMaxIndex = maxindexVectorInt32 -  numEltDotprod = dotprodVectorInt32 +  numEltDotprodInner = dotprodinnerVectorInt32  instance NumElt Int64 where    numEltAdd = addVectorInt64 @@ -591,7 +601,7 @@ instance NumElt Int64 where    numEltProductFull = productFullVectorInt64    numEltMinIndex = minindexVectorInt64    numEltMaxIndex = maxindexVectorInt64 -  numEltDotprod = dotprodVectorInt64 +  numEltDotprodInner = dotprodinnerVectorInt64  instance NumElt Float where    numEltAdd = addVectorFloat @@ -606,7 +616,7 @@ instance NumElt Float where    numEltProductFull = productFullVectorFloat    numEltMinIndex = minindexVectorFloat    numEltMaxIndex = maxindexVectorFloat -  numEltDotprod = dotprodVectorFloat +  numEltDotprodInner = dotprodinnerVectorFloat  instance NumElt Double where    numEltAdd = addVectorDouble @@ -621,7 +631,7 @@ instance NumElt Double where    numEltProductFull = productFullVectorDouble    numEltMinIndex = minindexVectorDouble    numEltMaxIndex = maxindexVectorDouble -  numEltDotprod = dotprodVectorDouble +  numEltDotprodInner = dotprodinnerVectorDouble  instance NumElt Int where    numEltAdd = intWidBranch2 @Int (+) @@ -646,8 +656,8 @@ instance NumElt Int where    numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))    numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64    numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 -  numEltDotprod = intWidBranchDotprod @Int (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprod_i32 c_dotprod_i32_strided -                                           (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprod_i64 c_dotprod_i64_strided +  numEltDotprodInner = intWidBranchDotprod @Int (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 +                                                (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64  instance NumElt CInt where    numEltAdd = intWidBranch2 @CInt (+) @@ -672,8 +682,8 @@ instance NumElt CInt where    numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))    numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64    numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 -  numEltDotprod = intWidBranchDotprod @CInt (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprod_i32 c_dotprod_i32_strided -                                            (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprod_i64 c_dotprod_i64_strided +  numEltDotprodInner = intWidBranchDotprod @CInt (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 +                                                 (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64  class FloatElt a where    floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index c1c0070..ade7ce1 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -22,6 +22,7 @@ $(do          ,("extremum_max_" ++ tyn,          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ,("dotprod_" ++ tyn,               [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |])          ,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]) +        ,("dotprodinner_" ++ tyn,          [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ]    let importsFloat ttyp tyn = diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs index 633c9c2..ec7e7bd 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -108,6 +108,9 @@ lemTakeLenApp _ _ _ = unsafeCoerceRefl  lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l  lemInitApp _ _ = unsafeCoerceRefl +lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x +lemLastApp _ _ = unsafeCoerceRefl +  -- ** KnownNat diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index a15e0a2..1285aa1 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -121,6 +121,10 @@ listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f  listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh  listxInit (_ ::% ZX) = ZX +listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) +listxLast (_ ::% sh@(_ ::% _)) = listxLast sh +listxLast (x ::% ZX) = x +  -- * Mixed indices @@ -179,6 +183,9 @@ ixxDrop = coerce (listxDrop @(Const i) @(Const i))  ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i  ixxInit = coerce (listxInit @(Const i)) +ixxLast :: forall n sh i. IxX (n : sh) i -> i +ixxLast = coerce (listxLast @(Const i)) +  ixxFromLinear :: IShX sh -> Int -> IIxX sh  ixxFromLinear = \sh i -> case go sh i of    (idx, 0) -> idx @@ -330,6 +337,9 @@ shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))  shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i  shxInit = coerce (listxInit @(SMayNat i SNat)) +shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) +shxLast = coerce (listxLast @(SMayNat i SNat)) +  shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i  shxTakeSSX _ = flip go    where @@ -404,6 +414,12 @@ ssxTail (_ :!% ssh) = ssh  ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'  ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) +ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) +ssxInit = coerce (listxInit @(SMayNat () SNat)) + +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) +ssxLast = coerce (listxLast @(SMayNat () SNat)) +  -- | This may fail if @sh@ has @Nothing@s in it.  ssxToShX' :: StaticShX sh -> Maybe (IShX sh)  ssxToShX' ZKX = Just ZSX diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs index 22d06e5..8e90a88 100644 --- a/src/Data/Array/Mixed/Types.hs +++ b/src/Data/Array/Mixed/Types.hs @@ -29,6 +29,7 @@ module Data.Array.Mixed.Types (    MapJust,    Tail,    Init, +  Last,    -- * Unsafe    unsafeCoerceRefl, @@ -105,6 +106,10 @@ type family Init l where    Init (x : y : xs) = x : Init (y : xs)    Init '[x] = '[] +type family Last l where +  Last (x : y : xs) = Last (y : xs) +  Last '[x] = x +  -- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to  -- only typecheck for actual type equalities. One cannot, e.g. accidentally diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 53417bd..f37619f 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -12,7 +12,7 @@ module Data.Array.Nested (    rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,    rfromListLinear, rtoListLinear,    rslice, rrev1, rreshape, rflatten, riota, -  rminIndexPrim, rmaxIndexPrim, rdot, rdot1, +  rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot,    rnest, runNest,    -- ** Lifting orthotope operations to 'Ranked' arrays    rlift, rlift2, @@ -33,7 +33,7 @@ module Data.Array.Nested (    sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,    sfromListLinear, stoListLinear,    sslice, srev1, sreshape, sflatten, siota, -  sminIndexPrim, smaxIndexPrim, sdot, sdot1, +  sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot,    snest, sunNest,    -- ** Lifting orthotope operations to 'Shaped' arrays    slift, slift2, @@ -54,7 +54,7 @@ module Data.Array.Nested (    mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,    mfromListLinear, mtoListLinear,    mslice, mrev1, mreshape, mflatten, miota, -  mminIndexPrim, mmaxIndexPrim, mdot, mdot1, +  mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot,    mnest, munNest,    -- ** Lifting orthotope operations to 'Mixed' arrays    mlift, mlift2, diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 215313e..50202ba 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -104,7 +104,7 @@ newtype Primitive a = Primitive a  -- | Element types that are primitive; arrays of these types are just a newtype  -- wrapper over an array. -class Storable a => PrimElt a where +class (Storable a, Elt a) => PrimElt a where    fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a    toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) @@ -854,15 +854,26 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh  mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =    ixxFromList (ssxFromShape sh) (numEltMaxIndex arr) -mdot1 :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a -mdot1 (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) = -  numEltDotprod arr1 arr2 +mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) +           => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a +mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) +  | Refl <- lemInitApp (Proxy @sh) (Proxy @n) +  , Refl <- lemLastApp (Proxy @sh) (Proxy @n) +  = case sh1 of +      _ :$% _ +        | sh1 == sh2 +        , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> +            fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b)) +        | otherwise -> error "mdot1Inner: Unequal shapes" +      ZSX -> error "unreachable"  -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'mdot1' if applicable. +-- Prefer 'mdot1Inner' if applicable.  mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a -mdot a b = mdot1 (fromPrimitive (mflatten (toPrimitive a))) -                 (fromPrimitive (mflatten (toPrimitive b))) +mdot a b = +  munScalar $ +    mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a))) +                     (fromPrimitive (mflatten (toPrimitive b)))  mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)  mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 74b2186..735d1a3 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -483,11 +483,14 @@ rmaxIndexPrim rarr@(Ranked arr)    | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))    = ixCvtXR (mmaxIndexPrim arr) -rdot1 :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a -rdot1 = coerce mdot1 +rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a +rdot1Inner arr1 arr2 +  | SNat <- rrank arr1 +  , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat)) +  = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2  -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'rdot1' if applicable. +-- Prefer 'rdot1Inner' if applicable.  rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a  rdot = coerce mdot diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs index ca04840..7077053 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Internal/Shape.hs @@ -87,6 +87,16 @@ listrTail :: ListR (n + 1) i -> ListR n i  listrTail (_ ::: sh) = sh  listrTail ZR = error "unreachable" +listrInit :: ListR (n + 1) i -> ListR n i +listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh +listrInit (_ ::: ZR) = ZR +listrInit ZR = error "unreachable" + +listrLast :: ListR (n + 1) i -> i +listrLast (_ ::: sh@(_ ::: _)) = listrLast sh +listrLast (n ::: ZR) = n +listrLast ZR = error "unreachable" +  listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i  listrIndex SZ (x ::: _) = x  listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs @@ -166,6 +176,12 @@ ixrHead (IxR list) = listrHead list  ixrTail :: IxR (n + 1) i -> IxR n i  ixrTail (IxR list) = IxR (listrTail list) +ixrInit :: IxR (n + 1) i -> IxR n i +ixrInit (IxR list) = IxR (listrInit list) + +ixrLast :: IxR (n + 1) i -> i +ixrLast (IxR list) = listrLast list +  ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i  ixrAppend = coerce (listrAppend @_ @i) @@ -235,6 +251,12 @@ shrHead (ShR list) = listrHead list  shrTail :: ShR (n + 1) i -> ShR n i  shrTail (ShR list) = ShR (listrTail list) +shrInit :: ShR (n + 1) i -> ShR n i +shrInit (ShR list) = ShR (listrInit list) + +shrLast :: ShR (n + 1) i -> i +shrLast (ShR list) = listrLast list +  shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i  shrAppend = coerce (listrAppend @_ @i) @@ -310,17 +332,25 @@ listsToList :: ListS sh (Const i) -> [i]  listsToList ZS = []  listsToList (Const i ::$ is) = i : listsToList is -listsHead :: ListS (n : sh) i -> i n +listsHead :: ListS (n : sh) f -> f n  listsHead (i ::$ _) = i -listsTail :: ListS (n : sh) i -> ListS sh i +listsTail :: ListS (n : sh) f -> ListS sh f  listsTail (_ ::$ sh) = sh +listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f +listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh +listsInit (_ ::$ ZS) = ZS + +listsLast :: ListS (n : sh) f -> f (Last (n : sh)) +listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh +listsLast (n ::$ ZS) = n +  listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f  listsAppend ZS idx' = idx'  listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' -listsRank :: ListS sh i -> SNat (Rank sh) +listsRank :: ListS sh f -> SNat (Rank sh)  listsRank ZS = SNat  listsRank (_ ::$ sh) = snatSucc (listsRank sh) @@ -403,6 +433,12 @@ ixsHead (IxS list) = getConst (listsHead list)  ixsTail :: IxS (n : sh) i -> IxS sh i  ixsTail (IxS list) = IxS (listsTail list) +ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i +ixsInit (IxS list) = IxS (listsInit list) + +ixsLast :: IxS (n : sh) i -> i +ixsLast (IxS list) = getConst (listsLast list) +  ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i  ixsAppend = coerce (listsAppend @_ @(Const i)) @@ -469,6 +505,12 @@ shsHead (ShS list) = listsHead list  shsTail :: ShS (n : sh) -> ShS sh  shsTail (ShS list) = ShS (listsTail list) +shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) +shsInit (ShS list) = ShS (listsInit list) + +shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS list) = listsLast list +  shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')  shsAppend = coerce (listsAppend @_ @SNat) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index d013959..995507f 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -418,11 +418,19 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde  smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh  smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) -sdot1 :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a -sdot1 = coerce mdot1 +sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) +           => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a +sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) +  | Refl <- lemInitApp (Proxy @sh) (Proxy @n) +  , Refl <- lemLastApp (Proxy @sh) (Proxy @n) +  = case sshape sarr1 of +      _ :$$ _ +        | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) +        -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) +      _ -> error "unreachable"  -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'sdot1' if applicable. +-- Prefer 'sdot1Inner' if applicable.  sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a  sdot = coerce mdot | 
