diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-23 00:18:17 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-23 00:18:17 +0200 | 
| commit | a0010622885dcb55a916bf3514c0e9040f6871e9 (patch) | |
| tree | 9e10c18eaf5c873d50e1f88a3bf114179c151769 /src/Data | |
| parent | 4b74d1b1f7c46a4b3907838bee11f669060d3a23 (diff) | |
Fast numeric operations for Num
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 46 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 240 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Foreign.hs | 33 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists.hs | 47 | 
4 files changed, 358 insertions, 8 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 588237d..831a9b5 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE ConstraintKinds #-}  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE DefaultSignatures #-}  {-# LANGUAGE DeriveFoldable #-} @@ -62,6 +63,7 @@ import Unsafe.Coerce  import Data.Array.Mixed  import qualified Data.Array.Mixed as X +import Data.Array.Nested.Internal.Arith  -- Invariant in the API @@ -999,6 +1001,7 @@ mliftPrim2 :: PrimElt a  mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =    fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) +{-}  instance (Num a, PrimElt a) => Num (Mixed sh a) where    (+) = mliftPrim2 (+)    (-) = mliftPrim2 (-) @@ -1008,12 +1011,39 @@ instance (Num a, PrimElt a) => Num (Mixed sh a) where    signum = mliftPrim signum    fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" -instance (Fractional a, PrimElt a) => Fractional (Mixed sh a) where +type NumConstr a = Num a +--} + +{--} +mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a +mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (srankSh sh) arr)) + +mliftNumElt2 :: PrimElt a +             => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a) +             -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) +  | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (srankSh sh1) arr1 arr2)) +  | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 + +-- TODO: Clean up this mess and remove NumConstr +type NumConstr a = NumElt a + +instance (NumElt a, PrimElt a) => Num (Mixed sh a) where +  (+) = mliftNumElt2 numEltAdd +  (-) = mliftNumElt2 numEltSub +  (*) = mliftNumElt2 numEltMul +  negate = mliftNumElt1 numEltNeg +  abs = mliftNumElt1 numEltAbs +  signum = mliftNumElt1 numEltSignum +  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" +--} + +instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Mixed sh a) where    fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"    recip = mliftPrim recip    (/) = mliftPrim2 (/) -instance (Floating a, PrimElt a) => Floating (Mixed sh a) where +instance (Floating a, PrimElt a, NumConstr a) => Floating (Mixed sh a) where    pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"    exp = mliftPrim exp    log = mliftPrim log @@ -1316,7 +1346,7 @@ arithPromoteRanked2 :: forall n a. PrimElt a                      -> Ranked n a -> Ranked n a -> Ranked n a  arithPromoteRanked2 = coerce -instance (Num a, PrimElt a) => Num (Ranked n a) where +instance (NumConstr a, PrimElt a) => Num (Ranked n a) where    (+) = arithPromoteRanked2 (+)    (-) = arithPromoteRanked2 (-)    (*) = arithPromoteRanked2 (*) @@ -1325,12 +1355,12 @@ instance (Num a, PrimElt a) => Num (Ranked n a) where    signum = arithPromoteRanked signum    fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate" -instance (Fractional a, PrimElt a) => Fractional (Ranked n a) where +instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Ranked n a) where    fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate"    recip = arithPromoteRanked recip    (/) = arithPromoteRanked2 (/) -instance (Floating a, PrimElt a) => Floating (Ranked n a) where +instance (Floating a, PrimElt a, NumConstr a) => Floating (Ranked n a) where    pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate"    exp = arithPromoteRanked exp    log = arithPromoteRanked log @@ -1616,7 +1646,7 @@ arithPromoteShaped2 :: forall sh a. PrimElt a                      -> Shaped sh a -> Shaped sh a -> Shaped sh a  arithPromoteShaped2 = coerce -instance (Num a, PrimElt a) => Num (Shaped sh a) where +instance (NumConstr a, PrimElt a) => Num (Shaped sh a) where    (+) = arithPromoteShaped2 (+)    (-) = arithPromoteShaped2 (-)    (*) = arithPromoteShaped2 (*) @@ -1625,12 +1655,12 @@ instance (Num a, PrimElt a) => Num (Shaped sh a) where    signum = arithPromoteShaped signum    fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate" -instance (Fractional a, PrimElt a) => Fractional (Shaped sh a) where +instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Shaped sh a) where    fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicate"    recip = arithPromoteShaped recip    (/) = arithPromoteShaped2 (/) -instance (Floating a, PrimElt a) => Floating (Shaped sh a) where +instance (Floating a, PrimElt a, NumConstr a) => Floating (Shaped sh a) where    pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicate"    exp = arithPromoteShaped exp    log = arithPromoteShaped log diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs new file mode 100644 index 0000000..4312cd5 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -0,0 +1,240 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Internal.Arith where + +import Control.Monad (forM, guard) +import qualified Data.Array.Internal as OI +import qualified Data.Array.Internal.RankedG as RG +import qualified Data.Array.Internal.RankedS as RS +import Data.Bits +import Data.Int +import Data.List (sort) +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.C.Types +import Foreign.Ptr +import Foreign.Storable (Storable) +import GHC.TypeLits +import Language.Haskell.TH +import System.IO.Unsafe + +import Data.Array.Nested.Internal.Arith.Foreign +import Data.Array.Nested.Internal.Arith.Lists + + +mliftVEltwise1 :: Storable a +               => SNat n +               -> (VS.Vector a -> VS.Vector a) +               -> RS.Array n a -> RS.Array n a +mliftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) +  | Just prefixSz <- stridesDense sh strides = +      let vec' = f (VS.slice offset prefixSz vec) +      in RS.A (RG.A sh (OI.T strides 0 vec')) +  | otherwise = RS.fromVector sh (f (RS.toVector arr)) + +mliftVEltwise2 :: Storable a +               => SNat n +               -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) +               -> RS.Array n a -> RS.Array n a -> RS.Array n a +mliftVEltwise2 SNat f +    arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) +    arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) +  | sh1 /= sh2 = error $ "mliftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 +  | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs +  | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of +      (Just 1, Just 1) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars +        let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) +        in RS.A (RG.A sh1 (OI.T strides1 0 vec')) +      (Just 1, Just n) ->  -- scalar * dense +        RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))) +      (Just n, Just 1) ->  -- dense * scalar +        RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))) +      (_, _) ->  -- fallback case +        RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) + +-- | Given the shape vector and the stride vector, return whether this vector +-- of strides uses a dense prefix of its backing array. If so, the number of +-- elements in this prefix is returned. +-- This excludes any offset. +stridesDense :: [Int] -> [Int] -> Maybe Int +stridesDense sh _ | any (<= 0) sh = Just 0 +stridesDense sh str = +  -- sort dimensions on their stride, ascending +  case sort (zip str sh) of +    [] -> Just 0 +    (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' +    _ -> error "Orthotope array's shape vector and stride vector have different lengths" +  where +    -- Given size of currently densely covered region at beginning of the +    -- array, the remaining shape vector and the corresponding remaining stride +    -- vector, return whether this all together covers a dense prefix of the +    -- array. If it does, return the number of elements in this prefix. +    checkCover :: Int -> [Int] -> [Int] -> Maybe Int +    checkCover block [] [] = Just block +    checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' +    checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" + +vectorOp1 :: forall a b. Storable a +          => (Ptr a -> Ptr b) +          -> (Int64 -> Ptr b -> Ptr b -> IO ()) +          -> VS.Vector a -> VS.Vector a +vectorOp1 ptrconv f v = unsafePerformIO $ do +  outv <- VSM.unsafeNew (VS.length v) +  VSM.unsafeWith outv $ \poutv -> +    VS.unsafeWith v $ \pv -> +      f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) +  VS.unsafeFreeze outv + +-- | If two vectors are given, assumes that they have the same length. +vectorOp2 :: forall a b. Storable a +          => (a -> b) +          -> (Ptr a -> Ptr b) +          -> (a -> a -> a) +          -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- sv +          -> (Int64 -> Ptr b -> Ptr b -> b -> IO ())  -- vs +          -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ())  -- vv +          -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a +vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases +  (Left x) (Left y) -> VS.singleton (fss x y) + +  (Left x) (Right vy) -> +    unsafePerformIO $ do +      outv <- VSM.unsafeNew (VS.length vy) +      VSM.unsafeWith outv $ \poutv -> +        VS.unsafeWith vy $ \pvy -> +          fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) +      VS.unsafeFreeze outv + +  (Right vx) (Left y) -> +    unsafePerformIO $ do +      outv <- VSM.unsafeNew (VS.length vx) +      VSM.unsafeWith outv $ \poutv -> +        VS.unsafeWith vx $ \pvx -> +          fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) +      VS.unsafeFreeze outv + +  (Right vx) (Right vy) +    | VS.length vx == VS.length vy -> +        unsafePerformIO $ do +          outv <- VSM.unsafeNew (VS.length vx) +          VSM.unsafeWith outv $ \poutv -> +            VS.unsafeWith vx $ \pvx -> +              VS.unsafeWith vy $ \pvy -> +                fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) +          VS.unsafeFreeze outv +    | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) + +flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) +       ->  Int64 -> Ptr a -> Ptr a -> a -> IO () +flipOp f n out v s = f n out s v + +class NumElt a where +  numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a +  numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a +  numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM binopsList $ \arithop -> do +      let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          cnamebase = "c_" ++ aboName arithop ++ "_" ++ atCName arithtype +          c_ss = varE (aboScalFun arithop arithtype) +          c_sv = varE $ mkName (cnamebase ++ "_sv") +          c_vs | aboComm arithop == NonComm = varE $ mkName (cnamebase ++ "_vs") +               | otherwise = [| flipOp $c_sv |] +          c_vv = varE $ mkName (cnamebase ++ "_vv") +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> mliftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM unopsList $ \arithop -> do +      let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> mliftVEltwise1 sn (vectorOp1 id $c_op) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +-- This branch is ostensibly a runtime branch, but will (hopefully) be +-- constant-folded away by GHC. +intWidBranch1 :: forall i n. (FiniteBits i, Storable i) +              => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) +              -> (SNat n -> RS.Array n i -> RS.Array n i) +intWidBranch1 f32 f64 sn +  | finiteBitSize (undefined :: i) == 32 = mliftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) +  | finiteBitSize (undefined :: i) == 64 = mliftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) +  | otherwise = error "Unsupported Int width" + +intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) +              => (i -> i -> i)  -- ss +                 -- int32 +              -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- sv +              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ())  -- vs +              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- vv +                 -- int64 +              -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- sv +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv +              -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) +intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn +  | finiteBitSize (undefined :: i) == 32 = mliftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) +  | finiteBitSize (undefined :: i) == 64 = mliftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) +  | otherwise = error "Unsupported Int width" + +instance NumElt Int32 where +  numEltAdd = addVectorInt32 +  numEltSub = subVectorInt32 +  numEltMul = mulVectorInt32 +  numEltNeg = negVectorInt32 +  numEltAbs = absVectorInt32 +  numEltSignum = signumVectorInt32 + +instance NumElt Int64 where +  numEltAdd = addVectorInt64 +  numEltSub = subVectorInt64 +  numEltMul = mulVectorInt64 +  numEltNeg = negVectorInt64 +  numEltAbs = absVectorInt64 +  numEltSignum = signumVectorInt64 + +instance NumElt Float where +  numEltAdd = addVectorFloat +  numEltSub = subVectorFloat +  numEltMul = mulVectorFloat +  numEltNeg = negVectorFloat +  numEltAbs = absVectorFloat +  numEltSignum = signumVectorFloat + +instance NumElt Double where +  numEltAdd = addVectorDouble +  numEltSub = subVectorDouble +  numEltMul = mulVectorDouble +  numEltNeg = negVectorDouble +  numEltAbs = absVectorDouble +  numEltSignum = signumVectorDouble + +instance NumElt Int where +  numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv +  numEltSub = intWidBranch2 @Int (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv +  numEltMul = intWidBranch2 @Int (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv +  numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64 +  numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64 +  numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64 + +instance NumElt CInt where +  numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv +  numEltSub = intWidBranch2 @CInt (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv +  numEltMul = intWidBranch2 @CInt (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv +  numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64 +  numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64 +  numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64 diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs new file mode 100644 index 0000000..dbd9ddc --- /dev/null +++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs @@ -0,0 +1,33 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Nested.Internal.Arith.Foreign where + +import Control.Monad +import Data.Int +import Data.Maybe +import Foreign.Ptr +import Language.Haskell.TH + +import Data.Array.Nested.Internal.Arith.Lists + + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM binopsList $ \arithop -> do +      let base = aboName arithop ++ "_" ++ atCName arithtype +      sequence $ catMaybes +        [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> +                 [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) +        ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> +                 [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +        ,guard (aboComm arithop == NonComm) >> +           Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> +                    [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) +        ]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    forM unopsList $ \arithop -> do +      let base = auoName arithop ++ "_" ++ atCName arithtype +      ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +        [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs new file mode 100644 index 0000000..1b29770 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Nested.Internal.Arith.Lists where + +import Data.Int + +import Language.Haskell.TH + + +data Commutative = Comm | NonComm +  deriving (Show, Eq) + +data ArithType = ArithType +  { atType :: Name  -- ''Int32 +  , atCName :: String  -- "i32" +  } + +typesList :: [ArithType] +typesList = +  [ArithType ''Int32 "i32" +  ,ArithType ''Int64 "i64" +  ,ArithType ''Float "float" +  ,ArithType ''Double "double" +  ] + +data ArithBOp = ArithBOp +  { aboName :: String  -- "add" +  , aboComm :: Commutative  -- Comm +  , aboScalFun :: ArithType -> Name  -- \_ -> '(+) +  } + +binopsList :: [ArithBOp] +binopsList = +  [ArithBOp "add" Comm (\_ -> '(+)) +  ,ArithBOp "sub" NonComm (\_ -> '(-)) +  ,ArithBOp "mul" Comm (\_ -> '(*)) +  ] + +data ArithUOp = ArithUOp +  { auoName :: String  -- "neg" +  } + +unopsList :: [ArithUOp] +unopsList = +  [ArithUOp "neg" +  ,ArithUOp "abs" +  ,ArithUOp "signum" +  ] | 
