diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-05 22:09:50 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-05 22:09:50 +0100 | 
| commit | 984e5315768dd190a97069167daf970c17c3c867 (patch) | |
| tree | 7db8a4173a4198ba0a3fafa54799fd7273f0bfdb /src/Data | |
| parent | 37eec011de921504dc16fd16ec9bb0e5008347fd (diff) | |
arith: Unary float ops on strided arrays without normalisation
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 2 | 
2 files changed, 3 insertions, 3 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index a403d3c..11ee3fe 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -460,10 +460,10 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do      let ttyp = conT (atType arithtype)      fmap concat . forM [minBound..maxBound] $ \arithop -> do        let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) +          c_op_strided = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] +               ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |]                     return $ FunD name [Clause [] (NormalB body) []]])  mulWithInt :: Num a => a -> Int -> a diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index b53eb36..a60b717 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -29,7 +29,7 @@ $(do          [("fbinary_" ++ tyn ++ "_vv",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])          ,("fbinary_" ++ tyn ++ "_sv",      [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |])          ,("fbinary_" ++ tyn ++ "_vs",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |]) -        ,("funary_" ++ tyn,                [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +        ,("funary_" ++ tyn ++ "_strided",  [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ]    let generate types imports = | 
