diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-06-18 21:55:08 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-18 21:55:08 +0200 |
commit | c65320ad151cb5b92051866d17dcda49c7174e57 (patch) | |
tree | 104f4c69def294ebb7f2e5a1d49be166674fb8ab /src/Data/Array/Mixed/Internal | |
parent | 4a0b2ef27a6e31250c56faef0efc0abf611a0cda (diff) |
More sensible argument order to reduce1 C op
Diffstat (limited to 'src/Data/Array/Mixed/Internal')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 14 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 2 |
2 files changed, 8 insertions, 8 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index d547084..9f99c3b 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -170,7 +170,7 @@ vectorRedInnerOp :: forall a b n. (Num a, Storable a) -> (a -> b) -> (Ptr a -> Ptr b) -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel -> RS.Array (n + 1) a -> RS.Array n a vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) | null sh = error "unreachable" @@ -208,7 +208,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> - fred (fromIntegral ndimsF) pshF pstridesR (ptrconv poutvR) (ptrconv pvecR) + fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR) TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) -> RS.stretch (init sh) -- replicate to original shape . RS.reshape (init shOnes) -- add 1-sized dimensions where the original was replicated @@ -307,7 +307,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) vectorDotprodOp :: (Num a, Storable a) => (b -> a) -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ strided dotprod kernel -> RS.Array 1 a -> RS.Array 1 a -> a @@ -338,7 +338,7 @@ vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?" {-# NOINLINE dotScalarVector #-} dotScalarVector :: forall a b. (Num a, Storable a) => Int -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel -> a -> VS.Vector a -> a dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do alloca @a $ \pout -> do @@ -347,7 +347,7 @@ dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do alloca @Int64 $ \pstride -> do poke pstride 1 VS.unsafeWith vec $ \pvec -> - fred 1 pshape pstride (ptrconv pout) (ptrconv pvec) + fred 1 (ptrconv pout) pshape pstride (ptrconv pvec) res <- peek pout return (scalar * res) @@ -500,7 +500,7 @@ intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i) => -- int32 (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel -- int64 -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel @@ -535,7 +535,7 @@ intWidBranchExtr fextr32 fextr64 intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i) => -- int32 - (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel + (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32) -- ^ dotprod kernel -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32) -- ^ strided dotprod kernel -- int64 diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index ca96093..ef8f3cd 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -53,7 +53,7 @@ $(fmap concat . forM typesList $ \arithtype -> do basefull = "reducefull_" ++ atCName arithtype sequence [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base1) (mkName ("c_" ++ base1)) <$> - [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |] + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |] ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ basefull) (mkName ("c_" ++ basefull)) <$> [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]]) |