diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 56 |
1 files changed, 46 insertions, 10 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs index 7484455..07d5d8a 100644 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -170,16 +170,6 @@ flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -> Int64 -> Ptr a -> Ptr a -> a -> IO () flipOp f n out v s = f n out s v -class NumElt a where - numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a - numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a - numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a - numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -194,6 +184,20 @@ $(fmap concat . forM typesList $ \arithtype -> do ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_fbinary_" ++ atCName arithtype + c_ss = varE (afboNumOp arithop) + c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] + return $ FunD name [Clause [] (NormalB body) []]]) + $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -204,6 +208,16 @@ $(fmap concat . forM typesList $ \arithtype -> do ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] + return $ FunD name [Clause [] (NormalB body) []]]) + $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) fmap concat . forM [minBound..maxBound] $ \arithop -> do @@ -255,6 +269,16 @@ intWidBranchRed fsc32 fred32 fsc64 fred64 sn | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 | otherwise = error "Unsupported Int width" +class NumElt a where + numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a + numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a + numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a + numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a + numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a + instance NumElt Int32 where numEltAdd = addVectorInt32 numEltSub = subVectorInt32 @@ -334,3 +358,15 @@ instance NumElt CInt where numEltProduct1Inner = intWidBranchRed @CInt (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) + +class FloatElt a where + floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a + +instance FloatElt Float where + floatEltDiv = divVectorFloat + floatEltRecip = recipVectorFloat + +instance FloatElt Double where + floatEltDiv = divVectorDouble + floatEltRecip = recipVectorDouble |