diff options
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Mixed.hs | 62 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 12 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 118 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Foreign.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists.hs | 10 | 
5 files changed, 177 insertions, 34 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 9a77ccb..7293914 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -14,6 +14,7 @@  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE NoStarIsType #-}  {-# LANGUAGE StrictData #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-} @@ -44,6 +45,8 @@ import GHC.TypeLits  import qualified GHC.TypeNats as TypeNats  import Unsafe.Coerce (unsafeCoerce) +import Data.Array.Nested.Internal.Arith +  -- | Evidence for the constraint @c a@.  data Dict c a where @@ -120,6 +123,10 @@ foldListX f (x ::% xs) = f x <> foldListX f xs  lengthListX :: ListX sh f -> Int  lengthListX = getSum . foldListX (\_ -> Sum 1) +snatLengthListX :: ListX sh f -> SNat (Rank sh) +snatLengthListX ZX = SNat +snatLengthListX (_ ::% l) | SNat <- snatLengthListX l = SNat +  showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS  showListX f l = showString "[" . go "" l . showString "]"    where @@ -419,6 +426,26 @@ ssxIotaFrom :: Int -> StaticShX sh -> [Int]  ssxIotaFrom _ ZKX = []  ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh +type Flatten sh = Flatten' 1 sh + +type family Flatten' acc sh where +  Flatten' acc '[] = Just acc +  Flatten' acc (Nothing : sh) = Nothing +  Flatten' acc (Just n : sh) = Flatten' (acc * n) sh + +flattenSh :: IShX sh -> SMayNat Int SNat (Flatten sh) +flattenSh = go (SNat @1) +  where +    go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) +    go acc ZSX = SKnown acc +    go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) +    go acc (SKnown sn :$% sh) = go (mulSNat acc sn) sh + +    goUnknown :: Int -> IShX sh -> Int +    goUnknown acc ZSX = acc +    goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh +    goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh +  staticShapeFrom :: IShX sh -> StaticShX sh  staticShapeFrom ZSX = ZKX  staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh @@ -511,6 +538,10 @@ type family AddMaybe n m where  plusSNat :: SNat n -> SNat m -> SNat (n + m)  plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce +-- This should be a function in base +mulSNat :: SNat n -> SNat m -> SNat (n * m) +mulSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n * TypeNats.fromSNat m) unsafeCoerce +  smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)  smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)  smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) @@ -719,17 +750,36 @@ transpose2 ssh1 ssh2 (XArray arr)  sumFull :: (Storable a, Num a) => XArray sh a -> a  sumFull (XArray arr) = S.sumA arr -sumInner :: forall sh sh' a. (Storable a, Num a) +sumInner :: forall sh sh' a. (Storable a, NumElt a)           => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a -sumInner ssh ssh' +sumInner ssh ssh' arr    | Refl <- lemAppNil @sh -  = rerank ssh ssh' ZKX (scalar . sumFull) +  = let (_, sh') = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) +        sh'F = flattenSh sh' :$% ZSX +        ssh'F = staticShapeFrom sh'F + +        go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a +        go (XArray arr') +          | Refl <- lemRankApp ssh ssh'F +          , let sn = snatLengthListX (let StaticShX l = ssh in l) +          = XArray (numEltSum1Inner sn arr') + +    in go $ +       transpose2 ssh'F ssh $ +       reshapePartial ssh' ssh sh'F $ +       transpose2 ssh ssh' $ +         arr -sumOuter :: forall sh sh' a. (Storable a, Num a) +sumOuter :: forall sh sh' a. (Storable a, NumElt a)           => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a -sumOuter ssh ssh' +sumOuter ssh ssh' arr    | Refl <- lemAppNil @sh -  = sumInner ssh' ssh . transpose2 ssh ssh' +  = let (sh, _) = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) +        shF = flattenSh sh :$% ZSX +    in sumInner ssh' (staticShapeFrom shF) $ +       transpose2 (staticShapeFrom shF) ssh' $ +       reshapePartial ssh ssh' shF $ +         arr  fromListOuter :: forall n sh a. Storable a                => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index f3f8f7d..118612f 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -884,13 +884,13 @@ mgenerate sh f = case X.enumShape sh of                    mvecsWrite sh idx val vecs                  mvecsFreeze sh vecs -msumOuter1P :: forall sh n a. (Storable a, Num a) +msumOuter1P :: forall sh n a. (Storable a, NumElt a)              => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)  msumOuter1P (M_Primitive (n :$% sh) arr) =    let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX    in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr) -msumOuter1 :: forall sh n a. (Num a, PrimElt a) +msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)             => Mixed (n : sh) a -> Mixed sh a  msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive @@ -1466,13 +1466,13 @@ rlift2 :: forall n1 n2 n3 a. Elt a  rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)  rsumOuter1P :: forall n a. -               (Storable a, Num a) +               (Storable a, NumElt a)              => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)  rsumOuter1P (Ranked arr)    | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n    = Ranked (msumOuter1P arr) -rsumOuter1 :: forall n a. (Num a, PrimElt a) +rsumOuter1 :: forall n a. (NumElt a, PrimElt a)             => Ranked (n + 1) a -> Ranked n a  rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive @@ -1748,11 +1748,11 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a         -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a  slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2) -ssumOuter1P :: forall sh n a. (Storable a, Num a) +ssumOuter1P :: forall sh n a. (Storable a, NumElt a)              => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)  ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) -ssumOuter1 :: forall sh n a. (Num a, PrimElt a) +ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)             => Shaped (n : sh) a -> Shaped sh a  ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs index 4312cd5..bd582b7 100644 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -1,8 +1,11 @@ +{-# LANGUAGE DataKinds #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TemplateHaskell #-}  {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-}  {-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Data.Array.Nested.Internal.Arith where  import Control.Monad (forM, guard) @@ -25,24 +28,24 @@ import Data.Array.Nested.Internal.Arith.Foreign  import Data.Array.Nested.Internal.Arith.Lists -mliftVEltwise1 :: Storable a -               => SNat n -               -> (VS.Vector a -> VS.Vector a) -               -> RS.Array n a -> RS.Array n a -mliftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) +liftVEltwise1 :: Storable a +              => SNat n +              -> (VS.Vector a -> VS.Vector a) +              -> RS.Array n a -> RS.Array n a +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))    | Just prefixSz <- stridesDense sh strides =        let vec' = f (VS.slice offset prefixSz vec)        in RS.A (RG.A sh (OI.T strides 0 vec'))    | otherwise = RS.fromVector sh (f (RS.toVector arr)) -mliftVEltwise2 :: Storable a -               => SNat n -               -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) -               -> RS.Array n a -> RS.Array n a -> RS.Array n a -mliftVEltwise2 SNat f +liftVEltwise2 :: Storable a +              => SNat n +              -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) +              -> RS.Array n a -> RS.Array n a -> RS.Array n a +liftVEltwise2 SNat f      arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))      arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) -  | sh1 /= sh2 = error $ "mliftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 +  | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2    | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs    | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of        (Just 1, Just 1) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars @@ -62,11 +65,11 @@ mliftVEltwise2 SNat f  stridesDense :: [Int] -> [Int] -> Maybe Int  stridesDense sh _ | any (<= 0) sh = Just 0  stridesDense sh str = -  -- sort dimensions on their stride, ascending -  case sort (zip str sh) of -    [] -> Just 0 +  -- sort dimensions on their stride, ascending, dropping any zero strides +  case dropWhile ((== 0) . fst) (sort (zip str sh)) of +    [] -> Just 1      (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' -    _ -> error "Orthotope array's shape vector and stride vector have different lengths" +    _ -> Nothing  -- if the smallest stride is not 1, it will never be dense    where      -- Given size of currently densely covered region at beginning of the      -- array, the remaining shape vector and the corresponding remaining stride @@ -77,6 +80,7 @@ stridesDense sh str =      checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str'      checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" +{-# NOINLINE vectorOp1 #-}  vectorOp1 :: forall a b. Storable a            => (Ptr a -> Ptr b)            -> (Int64 -> Ptr b -> Ptr b -> IO ()) @@ -89,6 +93,7 @@ vectorOp1 ptrconv f v = unsafePerformIO $ do    VS.unsafeFreeze outv  -- | If two vectors are given, assumes that they have the same length. +{-# NOINLINE vectorOp2 #-}  vectorOp2 :: forall a b. Storable a            => (a -> b)            -> (Ptr a -> Ptr b) @@ -127,6 +132,39 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases            VS.unsafeFreeze outv      | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) +-- TODO: test all the weird cases of this function +-- | Reduce along the inner dimension +{-# NOINLINE vectorRedInnerOp #-} +vectorRedInnerOp :: forall a b n. (Num a, Storable a) +                 => SNat n +                 -> (a -> b) +                 -> (Ptr a -> Ptr b) +                 -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant +                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel +                 -> RS.Array (n + 1) a -> RS.Array n a +vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) +  | null sh = error "unreachable" +  | any (<= 0) sh = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty)) +  -- now the input array is nonempty +  | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) +  | last strides == 0 = +      liftVEltwise1 sn +        (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) +        (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) +  -- now there is useful work along the inner dimension +  | otherwise = +      let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those +          (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) +          ndimsF = length shF +      in unsafePerformIO $ do +           outv <- VSM.unsafeNew (product (init shF)) +           VSM.unsafeWith outv $ \poutv -> +             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> +               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> +                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> +                   fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) +           RS.fromVector (init sh) <$> VS.unsafeFreeze outv +  flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())         ->  Int64 -> Ptr a -> Ptr a -> a -> IO ()  flipOp f n out v s = f n out s v @@ -138,6 +176,8 @@ class NumElt a where    numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a    numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a    numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a +  numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a +  numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype) @@ -151,7 +191,7 @@ $(fmap concat . forM typesList $ \arithtype -> do            c_vv = varE $ mkName (cnamebase ++ "_vv")        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> mliftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] +               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]                     return $ FunD name [Clause [] (NormalB body) []]])  $(fmap concat . forM typesList $ \arithtype -> do @@ -161,7 +201,18 @@ $(fmap concat . forM typesList $ \arithtype -> do            c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype)        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> mliftVEltwise1 sn (vectorOp1 id $c_op) |] +               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM redopsList $ \redop -> do +      let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype) +          c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv") +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]                     return $ FunD name [Clause [] (NormalB body) []]])  -- This branch is ostensibly a runtime branch, but will (hopefully) be @@ -171,8 +222,8 @@ intWidBranch1 :: forall i n. (FiniteBits i, Storable i)                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())                -> (SNat n -> RS.Array n i -> RS.Array n i)  intWidBranch1 f32 f64 sn -  | finiteBitSize (undefined :: i) == 32 = mliftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) -  | finiteBitSize (undefined :: i) == 64 = mliftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) +  | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) +  | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)    | otherwise = error "Unsupported Int width"  intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) @@ -187,8 +238,21 @@ intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv                -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)  intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn -  | finiteBitSize (undefined :: i) == 32 = mliftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) -  | finiteBitSize (undefined :: i) == 64 = mliftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) +  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) +  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) +  | otherwise = error "Unsupported Int width" + +intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i) +                => -- int32 +                   (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant +                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel +                   -- int64 +                -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant +                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel +                -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) +intWidBranchRed fsc32 fred32 fsc64 fred64 sn +  | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 +  | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64    | otherwise = error "Unsupported Int width"  instance NumElt Int32 where @@ -198,6 +262,8 @@ instance NumElt Int32 where    numEltNeg = negVectorInt32    numEltAbs = absVectorInt32    numEltSignum = signumVectorInt32 +  numEltSum1Inner = sum1VectorInt32 +  numEltProduct1Inner = product1VectorInt32  instance NumElt Int64 where    numEltAdd = addVectorInt64 @@ -206,6 +272,8 @@ instance NumElt Int64 where    numEltNeg = negVectorInt64    numEltAbs = absVectorInt64    numEltSignum = signumVectorInt64 +  numEltSum1Inner = sum1VectorInt64 +  numEltProduct1Inner = product1VectorInt64  instance NumElt Float where    numEltAdd = addVectorFloat @@ -214,6 +282,8 @@ instance NumElt Float where    numEltNeg = negVectorFloat    numEltAbs = absVectorFloat    numEltSignum = signumVectorFloat +  numEltSum1Inner = sum1VectorFloat +  numEltProduct1Inner = product1VectorFloat  instance NumElt Double where    numEltAdd = addVectorDouble @@ -222,6 +292,8 @@ instance NumElt Double where    numEltNeg = negVectorDouble    numEltAbs = absVectorDouble    numEltSignum = signumVectorDouble +  numEltSum1Inner = sum1VectorDouble +  numEltProduct1Inner = product1VectorDouble  instance NumElt Int where    numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv @@ -230,6 +302,8 @@ instance NumElt Int where    numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64    numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64    numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64 +  numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 +  numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64  instance NumElt CInt where    numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv @@ -238,3 +312,5 @@ instance NumElt CInt where    numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64    numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64    numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64 +  numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 +  numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64 diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs index dbd9ddc..f84b1c5 100644 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs @@ -22,7 +22,7 @@ $(fmap concat . forM typesList $ \arithtype -> do                   [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])          ,guard (aboComm arithop == NonComm) >>             Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> -                    [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) +                   [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])          ])  $(fmap concat . forM typesList $ \arithtype -> do @@ -31,3 +31,10 @@ $(fmap concat . forM typesList $ \arithtype -> do        let base = auoName arithop ++ "_" ++ atCName arithtype        ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>          [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    forM redopsList $ \redop -> do +      let base = aroName redop ++ "_" ++ atCName arithtype +      ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +        [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs index 1b29770..78fe24a 100644 --- a/src/Data/Array/Nested/Internal/Arith/Lists.hs +++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs @@ -45,3 +45,13 @@ unopsList =    ,ArithUOp "abs"    ,ArithUOp "signum"    ] + +data ArithRedOp = ArithRedOp +  { aroName :: String  -- "sum" +  } + +redopsList :: [ArithRedOp] +redopsList = +  [ArithRedOp "sum1" +  ,ArithRedOp "product1" +  ] | 
