diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-06 00:08:40 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-12 22:25:35 +0100 | 
| commit | 766a925698a97cac03e972bdaa2500085be17c65 (patch) | |
| tree | db061ffe8993d5a25eb1c972a98e0917dd6a0fbf /src/Data/Array | |
| parent | 4d0f940f258d9bd0684607f996559d9d47968fdd (diff) | |
Binary ops without normalisation
Before:
> sum(*) Double [1e6] stride 1; -1:  OK
>   68.9 ms ± 4.7 ms
After:
> sum(*) Double [1e6] stride 1; -1:  OK
>   1.44 ms ±  50 μs
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 199 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 30 | 
2 files changed, 153 insertions, 76 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 11ee3fe..fede541 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -67,7 +67,7 @@ liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides                   VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh ->                     VS.unsafeWith (VS.singleton 1) $ \pstrides ->                       VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv -> -                      cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv) +                       cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv)                 RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv    | otherwise = unsafePerformIO $ do        outv <- VSM.unsafeNew (product sh) @@ -79,33 +79,52 @@ liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides        RS.fromVector sh <$> VS.unsafeFreeze outv  -- TODO: test all the cases of this thing with various input strides -liftVEltwise2 :: (Storable a, Storable b, Storable c) +liftVEltwise2 :: Storable a                => SNat n -              -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c) -              -> RS.Array n a -> RS.Array n b -> RS.Array n c -liftVEltwise2 SNat f +              -> (a -> b) +              -> (Ptr a -> Ptr b) +              -> (a -> a -> a) +              -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ sv +              -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())  -- ^ vs +              -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ vv +              -> RS.Array n a -> RS.Array n a -> RS.Array n a +liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv      arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))      arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))    | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2    | product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty))    | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of        (Just (_, 1), Just (_, 1)) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars -        let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) +        let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2))          in RS.A (RG.A sh1 (OI.T strides1 0 vec')) +        (Just (_, 1), Just (blockOff, blockSz)) ->  -- scalar * dense -        RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) -                             (f (Left (vec1 VS.! offset1)) (Right (VS.slice blockOff blockSz vec2))))) +        let arr2' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec2) +            RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2' +        in RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) resvec)) + +      (Just (_, 1), Nothing) ->  -- scalar * array +        wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2 +        (Just (blockOff, blockSz), Just (_, 1)) ->  -- dense * scalar -        RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) -                             (f (Right (VS.slice blockOff blockSz vec1)) (Left (vec2 VS.! offset2))))) +        let arr1' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec1) +            RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVS sn valconv ptrconv f_vs arr1' (vec2 VS.! offset2) +        in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) resvec)) + +      (Nothing, Just (_, 1)) ->  -- array * scalar +        wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2) +        (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2)) -        | blockSz1 == blockSz2  -- not sure if this check is necessary, might be implied by the below +        | blockSz1 == blockSz2  -- not sure if this check is necessary, might be implied by the strides check          , strides1 == strides2          ->  -- dense * dense but the strides match -          RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) -                               (f (Right (VS.slice blockOff1 blockSz1 vec1)) (Right (VS.slice blockOff2 blockSz2 vec2))))) +          let arr1' = RS.fromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1) +              arr2' = RS.fromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2) +              RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVV sn ptrconv f_vv arr1' arr2' +          in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) resvec)) +        (_, _) ->  -- fallback case -        RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) +        wrapBinaryVV sn ptrconv f_vv arr1 arr2  -- | Given shape vector, offset and stride vector, check whether this virtual  -- vector uses a dense subarray of its backing array. If so, the first index @@ -141,6 +160,57 @@ stridesDense sh offsetNeg stridesNeg =            in second ((-s) :) (flipReverseds sh' off' str')      flipReverseds _ _ _ = error "flipReverseds: invalid arguments" +{-# NOINLINE wrapBinarySV #-} +wrapBinarySV :: Storable a +             => SNat n +             -> (a -> b) +             -> (Ptr a -> Ptr b) +             -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) +             -> a -> RS.Array n a +             -> RS.Array n a +wrapBinarySV sn@SNat valconv ptrconv cf_strided x (RS.A (RG.A sh (OI.T strides offset vec))) = +  unsafePerformIO $ do +    outv <- VSM.unsafeNew (product sh) +    VSM.unsafeWith outv $ \poutv -> +      VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> +        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides -> +          VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv -> +            cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) (valconv x) pstrides (ptrconv pv) +    RS.fromVector sh <$> VS.unsafeFreeze outv + +wrapBinaryVS :: Storable a +             => SNat n +             -> (a -> b) +             -> (Ptr a -> Ptr b) +             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) +             -> RS.Array n a -> a +             -> RS.Array n a +wrapBinaryVS sn valconv ptrconv cf_strided arr y = +  wrapBinarySV sn valconv ptrconv +               (\rank psh poutv y' pstrides pv -> cf_strided rank psh poutv pstrides pv y') y arr + +-- | This function assumes that the two shapes are equal. +{-# NOINLINE wrapBinaryVV #-} +wrapBinaryVV :: Storable a +             => SNat n +             -> (Ptr a -> Ptr b) +             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) +             -> RS.Array n a -> RS.Array n a +             -> RS.Array n a +wrapBinaryVV sn@SNat ptrconv cf_strided +    (RS.A (RG.A sh (OI.T strides1 offset1 vec1))) +    (RS.A (RG.A _  (OI.T strides2 offset2 vec2))) = +  unsafePerformIO $ do +    outv <- VSM.unsafeNew (product sh) +    VSM.unsafeWith outv $ \poutv -> +      VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> +      VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides1)) $ \pstrides1 -> +      VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides2)) $ \pstrides2 -> +      VS.unsafeWith (VS.slice offset1 (VS.length vec1 - offset1) vec1) $ \pv1 -> +      VS.unsafeWith (VS.slice offset2 (VS.length vec2 - offset2) vec2) $ \pv2 -> +        cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) pstrides1 (ptrconv pv1) pstrides2 (ptrconv pv2) +    RS.fromVector sh <$> VS.unsafeFreeze outv +  {-# NOINLINE vectorOp1 #-}  vectorOp1 :: forall a b. Storable a            => (Ptr a -> Ptr b) @@ -286,7 +356,7 @@ vectorRedFullOp _ scaleval valbackconv ptrconv fred (RS.A (RG.A sh (OI.T strides             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->                 VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> -                 (`scaleval` fromIntegral multiplier) . valbackconv +                 (`scaleval` multiplier) . valbackconv                     <$> fred (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR)  -- TODO: test this function @@ -423,13 +493,13 @@ $(fmap concat . forM typesList $ \arithtype -> do      fmap concat . forM [minBound..maxBound] $ \arithop -> do        let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))            cnamebase = "c_binary_" ++ atCName arithtype -          c_ss = varE (aboNumOp arithop) -          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) -          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) -          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) +          c_ss_str = varE (aboNumOp arithop) +          c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) +          c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) +          c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] +               ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]                     return $ FunD name [Clause [] (NormalB body) []]])  $(fmap concat . forM floatTypesList $ \arithtype -> do @@ -437,13 +507,13 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do      fmap concat . forM [minBound..maxBound] $ \arithop -> do        let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))            cnamebase = "c_fbinary_" ++ atCName arithtype -          c_ss = varE (afboNumOp arithop) -          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) -          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) -          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +          c_ss_str = varE (afboNumOp arithop) +          c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +          c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +          c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] +               ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]                     return $ FunD name [Clause [] (NormalB body) []]])  $(fmap concat . forM typesList $ \arithtype -> do @@ -469,6 +539,13 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do  mulWithInt :: Num a => a -> Int -> a  mulWithInt a i = a * fromIntegral i +scaleFromSVStrided :: (Int64 -> Ptr Int64 -> Ptr a -> a -> Ptr Int64 -> Ptr a -> IO ()) +                   -> Int64 -> Ptr a -> a -> Ptr a -> IO () +scaleFromSVStrided fsv n out x ys = +  VS.unsafeWith (VS.singleton n) $ \psh -> +    VS.unsafeWith (VS.singleton 1) $ \pstrides -> +      fsv 1 psh out x pstrides ys +  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype)      fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -479,10 +556,10 @@ $(fmap concat . forM typesList $ \arithtype -> do            namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype))            c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))            c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) -          c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) +          c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))        sequence [SigD name1 <$>                       [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op1 |] +               ,do body <- [| \sn -> vectorRedInnerOp sn id id (scaleFromSVStrided $c_scale_op) $c_op1 |]                     return $ FunD name1 [Clause [] (NormalB body) []]                 ,SigD namefull <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> $ttyp |] @@ -505,11 +582,11 @@ $(fmap concat . forM typesList $ \arithtype -> do          name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype))          c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype))          mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype))) -        c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) +        c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))          c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))      sequence [SigD name <$>                     [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] -             ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |] +             ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op (scaleFromSVStrided $c_scale_op) $c_red_op $c_op |]                   return $ FunD name [Clause [] (NormalB body) []]])  -- This branch is ostensibly a runtime branch, but will (hopefully) be @@ -526,17 +603,17 @@ intWidBranch1 f32 f64 sn  intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)                => (i -> i -> i)  -- ss                   -- int32 -              -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- sv -              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ())  -- vs -              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- vv +              -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Int32 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- sv +              -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Int32 -> IO ())  -- vs +              -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- vv                   -- int64 -              -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- sv -              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs -              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- sv +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv                -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)  intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn -  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) -  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) +  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32 +  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64    | otherwise = error "Unsupported Int width"  intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i) @@ -667,55 +744,55 @@ instance NumElt Double where  instance NumElt Int where    numEltAdd = intWidBranch2 @Int (+) -                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) -                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) +                (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD)) +                (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))    numEltSub = intWidBranch2 @Int (-) -                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) -                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) +                (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB)) +                (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))    numEltMul = intWidBranch2 @Int (*) -                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) -                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) +                (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL)) +                (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))    numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))    numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))    numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))    numEltSum1Inner = intWidBranchRed1 @Int -                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) -                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) +                      (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) +                      (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM))    numEltProduct1Inner = intWidBranchRed1 @Int -                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT)) -                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT)) +                          (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_PRODUCT)) +                          (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_PRODUCT))    numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))    numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))    numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64    numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 -  numEltDotprodInner = intWidBranchDotprod @Int (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 -                                                (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 +  numEltDotprodInner = intWidBranchDotprod @Int (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 +                                                (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64  instance NumElt CInt where    numEltAdd = intWidBranch2 @CInt (+) -                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) -                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) +                (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD)) +                (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))    numEltSub = intWidBranch2 @CInt (-) -                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) -                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) +                (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB)) +                (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))    numEltMul = intWidBranch2 @CInt (*) -                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) -                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) +                (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL)) +                (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))    numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))    numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))    numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))    numEltSum1Inner = intWidBranchRed1 @CInt -                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) -                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) +                      (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) +                      (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM))    numEltProduct1Inner = intWidBranchRed1 @CInt -                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT)) -                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT)) +                          (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_PRODUCT)) +                          (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_PRODUCT))    numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))    numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))    numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64    numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 -  numEltDotprodInner = intWidBranchDotprod @CInt (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 -                                                 (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 +  numEltDotprodInner = intWidBranchDotprod @CInt (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 +                                                 (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64  class FloatElt a where    floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index a60b717..fa89766 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -12,24 +12,24 @@ import Data.Array.Mixed.Internal.Arith.Lists  $(do    let importsScal ttyp tyn = -        [("binary_" ++ tyn ++ "_vv",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) -        ,("binary_" ++ tyn ++ "_sv",       [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |]) -        ,("binary_" ++ tyn ++ "_vs",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |]) -        ,("unary_" ++ tyn ++ "_strided",   [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) -        ,("reduce1_" ++ tyn,               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) -        ,("reducefull_" ++ tyn,            [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]) -        ,("extremum_min_" ++ tyn,          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) -        ,("extremum_max_" ++ tyn,          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) -        ,("dotprod_" ++ tyn,               [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |]) -        ,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]) -        ,("dotprodinner_" ++ tyn,          [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        [("binary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("binary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp ->                  $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("binary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp ->                  $ttyp -> IO () |]) +        ,("unary_" ++ tyn ++ "_strided",     [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("reduce1_" ++ tyn,                 [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("reducefull_" ++ tyn,              [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]) +        ,("extremum_min_" ++ tyn,            [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("extremum_max_" ++ tyn,            [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("dotprod_" ++ tyn,                 [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |]) +        ,("dotprod_" ++ tyn ++ "_strided",   [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]) +        ,("dotprodinner_" ++ tyn,            [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ]    let importsFloat ttyp tyn = -        [("fbinary_" ++ tyn ++ "_vv",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) -        ,("fbinary_" ++ tyn ++ "_sv",      [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |]) -        ,("fbinary_" ++ tyn ++ "_vs",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |]) -        ,("funary_" ++ tyn ++ "_strided",  [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        [("fbinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("fbinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp ->                  $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("fbinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp ->                  $ttyp -> IO () |]) +        ,("funary_" ++ tyn ++ "_strided",     [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ]    let generate types imports = | 
