diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-23 13:47:18 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-23 13:47:18 +0200 | 
| commit | 4c86a3a4231cecc5b7c31491398f43b4ba667eea (patch) | |
| tree | 2e06f293f1350b7dd712bf1ad0eccb7b9d7686b4 /src/Data/Array/Nested/Internal | |
| parent | 827a9ce7adc6cf1debc08d154e4c11b7b83bfdf0 (diff) | |
Fast sum
Also fast product, but that's currently unused
Diffstat (limited to 'src/Data/Array/Nested/Internal')
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 118 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Foreign.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists.hs | 10 | 
3 files changed, 115 insertions, 22 deletions
| diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs index 4312cd5..bd582b7 100644 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -1,8 +1,11 @@ +{-# LANGUAGE DataKinds #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TemplateHaskell #-}  {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-}  {-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Data.Array.Nested.Internal.Arith where  import Control.Monad (forM, guard) @@ -25,24 +28,24 @@ import Data.Array.Nested.Internal.Arith.Foreign  import Data.Array.Nested.Internal.Arith.Lists -mliftVEltwise1 :: Storable a -               => SNat n -               -> (VS.Vector a -> VS.Vector a) -               -> RS.Array n a -> RS.Array n a -mliftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) +liftVEltwise1 :: Storable a +              => SNat n +              -> (VS.Vector a -> VS.Vector a) +              -> RS.Array n a -> RS.Array n a +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))    | Just prefixSz <- stridesDense sh strides =        let vec' = f (VS.slice offset prefixSz vec)        in RS.A (RG.A sh (OI.T strides 0 vec'))    | otherwise = RS.fromVector sh (f (RS.toVector arr)) -mliftVEltwise2 :: Storable a -               => SNat n -               -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) -               -> RS.Array n a -> RS.Array n a -> RS.Array n a -mliftVEltwise2 SNat f +liftVEltwise2 :: Storable a +              => SNat n +              -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) +              -> RS.Array n a -> RS.Array n a -> RS.Array n a +liftVEltwise2 SNat f      arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))      arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) -  | sh1 /= sh2 = error $ "mliftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 +  | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2    | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs    | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of        (Just 1, Just 1) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars @@ -62,11 +65,11 @@ mliftVEltwise2 SNat f  stridesDense :: [Int] -> [Int] -> Maybe Int  stridesDense sh _ | any (<= 0) sh = Just 0  stridesDense sh str = -  -- sort dimensions on their stride, ascending -  case sort (zip str sh) of -    [] -> Just 0 +  -- sort dimensions on their stride, ascending, dropping any zero strides +  case dropWhile ((== 0) . fst) (sort (zip str sh)) of +    [] -> Just 1      (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' -    _ -> error "Orthotope array's shape vector and stride vector have different lengths" +    _ -> Nothing  -- if the smallest stride is not 1, it will never be dense    where      -- Given size of currently densely covered region at beginning of the      -- array, the remaining shape vector and the corresponding remaining stride @@ -77,6 +80,7 @@ stridesDense sh str =      checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str'      checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" +{-# NOINLINE vectorOp1 #-}  vectorOp1 :: forall a b. Storable a            => (Ptr a -> Ptr b)            -> (Int64 -> Ptr b -> Ptr b -> IO ()) @@ -89,6 +93,7 @@ vectorOp1 ptrconv f v = unsafePerformIO $ do    VS.unsafeFreeze outv  -- | If two vectors are given, assumes that they have the same length. +{-# NOINLINE vectorOp2 #-}  vectorOp2 :: forall a b. Storable a            => (a -> b)            -> (Ptr a -> Ptr b) @@ -127,6 +132,39 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases            VS.unsafeFreeze outv      | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) +-- TODO: test all the weird cases of this function +-- | Reduce along the inner dimension +{-# NOINLINE vectorRedInnerOp #-} +vectorRedInnerOp :: forall a b n. (Num a, Storable a) +                 => SNat n +                 -> (a -> b) +                 -> (Ptr a -> Ptr b) +                 -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant +                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel +                 -> RS.Array (n + 1) a -> RS.Array n a +vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) +  | null sh = error "unreachable" +  | any (<= 0) sh = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty)) +  -- now the input array is nonempty +  | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) +  | last strides == 0 = +      liftVEltwise1 sn +        (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) +        (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) +  -- now there is useful work along the inner dimension +  | otherwise = +      let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those +          (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) +          ndimsF = length shF +      in unsafePerformIO $ do +           outv <- VSM.unsafeNew (product (init shF)) +           VSM.unsafeWith outv $ \poutv -> +             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> +               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> +                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> +                   fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) +           RS.fromVector (init sh) <$> VS.unsafeFreeze outv +  flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())         ->  Int64 -> Ptr a -> Ptr a -> a -> IO ()  flipOp f n out v s = f n out s v @@ -138,6 +176,8 @@ class NumElt a where    numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a    numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a    numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a +  numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a +  numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype) @@ -151,7 +191,7 @@ $(fmap concat . forM typesList $ \arithtype -> do            c_vv = varE $ mkName (cnamebase ++ "_vv")        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> mliftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] +               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]                     return $ FunD name [Clause [] (NormalB body) []]])  $(fmap concat . forM typesList $ \arithtype -> do @@ -161,7 +201,18 @@ $(fmap concat . forM typesList $ \arithtype -> do            c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype)        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> mliftVEltwise1 sn (vectorOp1 id $c_op) |] +               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM redopsList $ \redop -> do +      let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype) +          c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv") +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]                     return $ FunD name [Clause [] (NormalB body) []]])  -- This branch is ostensibly a runtime branch, but will (hopefully) be @@ -171,8 +222,8 @@ intWidBranch1 :: forall i n. (FiniteBits i, Storable i)                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())                -> (SNat n -> RS.Array n i -> RS.Array n i)  intWidBranch1 f32 f64 sn -  | finiteBitSize (undefined :: i) == 32 = mliftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) -  | finiteBitSize (undefined :: i) == 64 = mliftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) +  | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) +  | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)    | otherwise = error "Unsupported Int width"  intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) @@ -187,8 +238,21 @@ intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv                -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)  intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn -  | finiteBitSize (undefined :: i) == 32 = mliftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) -  | finiteBitSize (undefined :: i) == 64 = mliftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) +  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) +  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) +  | otherwise = error "Unsupported Int width" + +intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i) +                => -- int32 +                   (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant +                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel +                   -- int64 +                -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant +                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel +                -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) +intWidBranchRed fsc32 fred32 fsc64 fred64 sn +  | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 +  | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64    | otherwise = error "Unsupported Int width"  instance NumElt Int32 where @@ -198,6 +262,8 @@ instance NumElt Int32 where    numEltNeg = negVectorInt32    numEltAbs = absVectorInt32    numEltSignum = signumVectorInt32 +  numEltSum1Inner = sum1VectorInt32 +  numEltProduct1Inner = product1VectorInt32  instance NumElt Int64 where    numEltAdd = addVectorInt64 @@ -206,6 +272,8 @@ instance NumElt Int64 where    numEltNeg = negVectorInt64    numEltAbs = absVectorInt64    numEltSignum = signumVectorInt64 +  numEltSum1Inner = sum1VectorInt64 +  numEltProduct1Inner = product1VectorInt64  instance NumElt Float where    numEltAdd = addVectorFloat @@ -214,6 +282,8 @@ instance NumElt Float where    numEltNeg = negVectorFloat    numEltAbs = absVectorFloat    numEltSignum = signumVectorFloat +  numEltSum1Inner = sum1VectorFloat +  numEltProduct1Inner = product1VectorFloat  instance NumElt Double where    numEltAdd = addVectorDouble @@ -222,6 +292,8 @@ instance NumElt Double where    numEltNeg = negVectorDouble    numEltAbs = absVectorDouble    numEltSignum = signumVectorDouble +  numEltSum1Inner = sum1VectorDouble +  numEltProduct1Inner = product1VectorDouble  instance NumElt Int where    numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv @@ -230,6 +302,8 @@ instance NumElt Int where    numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64    numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64    numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64 +  numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 +  numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64  instance NumElt CInt where    numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv @@ -238,3 +312,5 @@ instance NumElt CInt where    numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64    numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64    numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64 +  numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 +  numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64 diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs index dbd9ddc..f84b1c5 100644 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs @@ -22,7 +22,7 @@ $(fmap concat . forM typesList $ \arithtype -> do                   [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])          ,guard (aboComm arithop == NonComm) >>             Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> -                    [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) +                   [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])          ])  $(fmap concat . forM typesList $ \arithtype -> do @@ -31,3 +31,10 @@ $(fmap concat . forM typesList $ \arithtype -> do        let base = auoName arithop ++ "_" ++ atCName arithtype        ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>          [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    forM redopsList $ \redop -> do +      let base = aroName redop ++ "_" ++ atCName arithtype +      ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +        [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs index 1b29770..78fe24a 100644 --- a/src/Data/Array/Nested/Internal/Arith/Lists.hs +++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs @@ -45,3 +45,13 @@ unopsList =    ,ArithUOp "abs"    ,ArithUOp "signum"    ] + +data ArithRedOp = ArithRedOp +  { aroName :: String  -- "sum" +  } + +redopsList :: [ArithRedOp] +redopsList = +  [ArithRedOp "sum1" +  ,ArithRedOp "product1" +  ] | 
