From 4c86a3a4231cecc5b7c31491398f43b4ba667eea Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 23 May 2024 13:47:18 +0200 Subject: Fast sum Also fast product, but that's currently unused --- src/Data/Array/Mixed.hs | 64 +++++++++++-- src/Data/Array/Nested/Internal.hs | 12 +-- src/Data/Array/Nested/Internal/Arith.hs | 118 +++++++++++++++++++----- src/Data/Array/Nested/Internal/Arith/Foreign.hs | 9 +- src/Data/Array/Nested/Internal/Arith/Lists.hs | 10 ++ 5 files changed, 178 insertions(+), 35 deletions(-) (limited to 'src/Data') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 9a77ccb..7293914 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -14,6 +14,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE NoStarIsType #-} {-# LANGUAGE StrictData #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} @@ -44,6 +45,8 @@ import GHC.TypeLits import qualified GHC.TypeNats as TypeNats import Unsafe.Coerce (unsafeCoerce) +import Data.Array.Nested.Internal.Arith + -- | Evidence for the constraint @c a@. data Dict c a where @@ -120,6 +123,10 @@ foldListX f (x ::% xs) = f x <> foldListX f xs lengthListX :: ListX sh f -> Int lengthListX = getSum . foldListX (\_ -> Sum 1) +snatLengthListX :: ListX sh f -> SNat (Rank sh) +snatLengthListX ZX = SNat +snatLengthListX (_ ::% l) | SNat <- snatLengthListX l = SNat + showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS showListX f l = showString "[" . go "" l . showString "]" where @@ -419,6 +426,26 @@ ssxIotaFrom :: Int -> StaticShX sh -> [Int] ssxIotaFrom _ ZKX = [] ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh +type Flatten sh = Flatten' 1 sh + +type family Flatten' acc sh where + Flatten' acc '[] = Just acc + Flatten' acc (Nothing : sh) = Nothing + Flatten' acc (Just n : sh) = Flatten' (acc * n) sh + +flattenSh :: IShX sh -> SMayNat Int SNat (Flatten sh) +flattenSh = go (SNat @1) + where + go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go acc ZSX = SKnown acc + go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) + go acc (SKnown sn :$% sh) = go (mulSNat acc sn) sh + + goUnknown :: Int -> IShX sh -> Int + goUnknown acc ZSX = acc + goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh + goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh + staticShapeFrom :: IShX sh -> StaticShX sh staticShapeFrom ZSX = ZKX staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh @@ -511,6 +538,10 @@ type family AddMaybe n m where plusSNat :: SNat n -> SNat m -> SNat (n + m) plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce +-- This should be a function in base +mulSNat :: SNat n -> SNat m -> SNat (n * m) +mulSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n * TypeNats.fromSNat m) unsafeCoerce + smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) @@ -719,17 +750,36 @@ transpose2 ssh1 ssh2 (XArray arr) sumFull :: (Storable a, Num a) => XArray sh a -> a sumFull (XArray arr) = S.sumA arr -sumInner :: forall sh sh' a. (Storable a, Num a) +sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a -sumInner ssh ssh' +sumInner ssh ssh' arr | Refl <- lemAppNil @sh - = rerank ssh ssh' ZKX (scalar . sumFull) - -sumOuter :: forall sh sh' a. (Storable a, Num a) + = let (_, sh') = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + sh'F = flattenSh sh' :$% ZSX + ssh'F = staticShapeFrom sh'F + + go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a + go (XArray arr') + | Refl <- lemRankApp ssh ssh'F + , let sn = snatLengthListX (let StaticShX l = ssh in l) + = XArray (numEltSum1Inner sn arr') + + in go $ + transpose2 ssh'F ssh $ + reshapePartial ssh' ssh sh'F $ + transpose2 ssh ssh' $ + arr + +sumOuter :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a -sumOuter ssh ssh' +sumOuter ssh ssh' arr | Refl <- lemAppNil @sh - = sumInner ssh' ssh . transpose2 ssh ssh' + = let (sh, _) = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + shF = flattenSh sh :$% ZSX + in sumInner ssh' (staticShapeFrom shF) $ + transpose2 (staticShapeFrom shF) ssh' $ + reshapePartial ssh ssh' shF $ + arr fromListOuter :: forall n sh a. Storable a => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index f3f8f7d..118612f 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -884,13 +884,13 @@ mgenerate sh f = case X.enumShape sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs -msumOuter1P :: forall sh n a. (Storable a, Num a) +msumOuter1P :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1P (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr) -msumOuter1 :: forall sh n a. (Num a, PrimElt a) +msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive @@ -1466,13 +1466,13 @@ rlift2 :: forall n1 n2 n3 a. Elt a rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) rsumOuter1P :: forall n a. - (Storable a, Num a) + (Storable a, NumElt a) => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = Ranked (msumOuter1P arr) -rsumOuter1 :: forall n a. (Num a, PrimElt a) +rsumOuter1 :: forall n a. (NumElt a, PrimElt a) => Ranked (n + 1) a -> Ranked n a rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive @@ -1748,11 +1748,11 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2) -ssumOuter1P :: forall sh n a. (Storable a, Num a) +ssumOuter1P :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) -ssumOuter1 :: forall sh n a. (Num a, PrimElt a) +ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs index 4312cd5..bd582b7 100644 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -1,8 +1,11 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Nested.Internal.Arith where import Control.Monad (forM, guard) @@ -25,24 +28,24 @@ import Data.Array.Nested.Internal.Arith.Foreign import Data.Array.Nested.Internal.Arith.Lists -mliftVEltwise1 :: Storable a - => SNat n - -> (VS.Vector a -> VS.Vector a) - -> RS.Array n a -> RS.Array n a -mliftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) +liftVEltwise1 :: Storable a + => SNat n + -> (VS.Vector a -> VS.Vector a) + -> RS.Array n a -> RS.Array n a +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) | Just prefixSz <- stridesDense sh strides = let vec' = f (VS.slice offset prefixSz vec) in RS.A (RG.A sh (OI.T strides 0 vec')) | otherwise = RS.fromVector sh (f (RS.toVector arr)) -mliftVEltwise2 :: Storable a - => SNat n - -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) - -> RS.Array n a -> RS.Array n a -> RS.Array n a -mliftVEltwise2 SNat f +liftVEltwise2 :: Storable a + => SNat n + -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) + -> RS.Array n a -> RS.Array n a -> RS.Array n a +liftVEltwise2 SNat f arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) - | sh1 /= sh2 = error $ "mliftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 + | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars @@ -62,11 +65,11 @@ mliftVEltwise2 SNat f stridesDense :: [Int] -> [Int] -> Maybe Int stridesDense sh _ | any (<= 0) sh = Just 0 stridesDense sh str = - -- sort dimensions on their stride, ascending - case sort (zip str sh) of - [] -> Just 0 + -- sort dimensions on their stride, ascending, dropping any zero strides + case dropWhile ((== 0) . fst) (sort (zip str sh)) of + [] -> Just 1 (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' - _ -> error "Orthotope array's shape vector and stride vector have different lengths" + _ -> Nothing -- if the smallest stride is not 1, it will never be dense where -- Given size of currently densely covered region at beginning of the -- array, the remaining shape vector and the corresponding remaining stride @@ -77,6 +80,7 @@ stridesDense sh str = checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" +{-# NOINLINE vectorOp1 #-} vectorOp1 :: forall a b. Storable a => (Ptr a -> Ptr b) -> (Int64 -> Ptr b -> Ptr b -> IO ()) @@ -89,6 +93,7 @@ vectorOp1 ptrconv f v = unsafePerformIO $ do VS.unsafeFreeze outv -- | If two vectors are given, assumes that they have the same length. +{-# NOINLINE vectorOp2 #-} vectorOp2 :: forall a b. Storable a => (a -> b) -> (Ptr a -> Ptr b) @@ -127,6 +132,39 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases VS.unsafeFreeze outv | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) +-- TODO: test all the weird cases of this function +-- | Reduce along the inner dimension +{-# NOINLINE vectorRedInnerOp #-} +vectorRedInnerOp :: forall a b n. (Num a, Storable a) + => SNat n + -> (a -> b) + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> RS.Array (n + 1) a -> RS.Array n a +vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) + | null sh = error "unreachable" + | any (<= 0) sh = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty)) + -- now the input array is nonempty + | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) + | last strides == 0 = + liftVEltwise1 sn + (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) + (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) + -- now there is useful work along the inner dimension + | otherwise = + let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those + (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) + ndimsF = length shF + in unsafePerformIO $ do + outv <- VSM.unsafeNew (product (init shF)) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> + VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> + VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> + fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) + RS.fromVector (init sh) <$> VS.unsafeFreeze outv + flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -> Int64 -> Ptr a -> Ptr a -> a -> IO () flipOp f n out v s = f n out s v @@ -138,6 +176,8 @@ class NumElt a where numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a + numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a + numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) @@ -151,7 +191,7 @@ $(fmap concat . forM typesList $ \arithtype -> do c_vv = varE $ mkName (cnamebase ++ "_vv") sequence [SigD name <$> [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> mliftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] + ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] return $ FunD name [Clause [] (NormalB body) []]]) $(fmap concat . forM typesList $ \arithtype -> do @@ -161,7 +201,18 @@ $(fmap concat . forM typesList $ \arithtype -> do c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype) sequence [SigD name <$> [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> mliftVEltwise1 sn (vectorOp1 id $c_op) |] + ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM redopsList $ \redop -> do + let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype) + c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv") + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] return $ FunD name [Clause [] (NormalB body) []]]) -- This branch is ostensibly a runtime branch, but will (hopefully) be @@ -171,8 +222,8 @@ intWidBranch1 :: forall i n. (FiniteBits i, Storable i) -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -> (SNat n -> RS.Array n i -> RS.Array n i) intWidBranch1 f32 f64 sn - | finiteBitSize (undefined :: i) == 32 = mliftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) - | finiteBitSize (undefined :: i) == 64 = mliftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) + | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) + | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) | otherwise = error "Unsupported Int width" intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) @@ -187,8 +238,21 @@ intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn - | finiteBitSize (undefined :: i) == 32 = mliftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) - | finiteBitSize (undefined :: i) == 64 = mliftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) + | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) + | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) + | otherwise = error "Unsupported Int width" + +intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i) + => -- int32 + (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel + -- int64 + -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel + -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) +intWidBranchRed fsc32 fred32 fsc64 fred64 sn + | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 + | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 | otherwise = error "Unsupported Int width" instance NumElt Int32 where @@ -198,6 +262,8 @@ instance NumElt Int32 where numEltNeg = negVectorInt32 numEltAbs = absVectorInt32 numEltSignum = signumVectorInt32 + numEltSum1Inner = sum1VectorInt32 + numEltProduct1Inner = product1VectorInt32 instance NumElt Int64 where numEltAdd = addVectorInt64 @@ -206,6 +272,8 @@ instance NumElt Int64 where numEltNeg = negVectorInt64 numEltAbs = absVectorInt64 numEltSignum = signumVectorInt64 + numEltSum1Inner = sum1VectorInt64 + numEltProduct1Inner = product1VectorInt64 instance NumElt Float where numEltAdd = addVectorFloat @@ -214,6 +282,8 @@ instance NumElt Float where numEltNeg = negVectorFloat numEltAbs = absVectorFloat numEltSignum = signumVectorFloat + numEltSum1Inner = sum1VectorFloat + numEltProduct1Inner = product1VectorFloat instance NumElt Double where numEltAdd = addVectorDouble @@ -222,6 +292,8 @@ instance NumElt Double where numEltNeg = negVectorDouble numEltAbs = absVectorDouble numEltSignum = signumVectorDouble + numEltSum1Inner = sum1VectorDouble + numEltProduct1Inner = product1VectorDouble instance NumElt Int where numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv @@ -230,6 +302,8 @@ instance NumElt Int where numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64 numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64 numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64 + numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 + numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64 instance NumElt CInt where numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv @@ -238,3 +312,5 @@ instance NumElt CInt where numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64 numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64 numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64 + numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 + numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64 diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs index dbd9ddc..f84b1c5 100644 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs @@ -22,7 +22,7 @@ $(fmap concat . forM typesList $ \arithtype -> do [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) ,guard (aboComm arithop == NonComm) >> Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> - [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) + [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) ]) $(fmap concat . forM typesList $ \arithtype -> do @@ -31,3 +31,10 @@ $(fmap concat . forM typesList $ \arithtype -> do let base = auoName arithop ++ "_" ++ atCName arithtype ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + forM redopsList $ \redop -> do + let base = aroName redop ++ "_" ++ atCName arithtype + ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs index 1b29770..78fe24a 100644 --- a/src/Data/Array/Nested/Internal/Arith/Lists.hs +++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs @@ -45,3 +45,13 @@ unopsList = ,ArithUOp "abs" ,ArithUOp "signum" ] + +data ArithRedOp = ArithRedOp + { aroName :: String -- "sum" + } + +redopsList :: [ArithRedOp] +redopsList = + [ArithRedOp "sum1" + ,ArithRedOp "product1" + ] -- cgit v1.2.3-70-g09d2