aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested/Internal.hs46
-rw-r--r--src/Data/Array/Nested/Internal/Arith.hs240
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs33
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists.hs47
4 files changed, 358 insertions, 8 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 588237d..831a9b5 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFoldable #-}
@@ -62,6 +63,7 @@ import Unsafe.Coerce
import Data.Array.Mixed
import qualified Data.Array.Mixed as X
+import Data.Array.Nested.Internal.Arith
-- Invariant in the API
@@ -999,6 +1001,7 @@ mliftPrim2 :: PrimElt a
mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
+{-}
instance (Num a, PrimElt a) => Num (Mixed sh a) where
(+) = mliftPrim2 (+)
(-) = mliftPrim2 (-)
@@ -1008,12 +1011,39 @@ instance (Num a, PrimElt a) => Num (Mixed sh a) where
signum = mliftPrim signum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
-instance (Fractional a, PrimElt a) => Fractional (Mixed sh a) where
+type NumConstr a = Num a
+--}
+
+{--}
+mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a
+mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (srankSh sh) arr))
+
+mliftNumElt2 :: PrimElt a
+ => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a)
+ -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2))
+ | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (srankSh sh1) arr1 arr2))
+ | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
+
+-- TODO: Clean up this mess and remove NumConstr
+type NumConstr a = NumElt a
+
+instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
+ (+) = mliftNumElt2 numEltAdd
+ (-) = mliftNumElt2 numEltSub
+ (*) = mliftNumElt2 numEltMul
+ negate = mliftNumElt1 numEltNeg
+ abs = mliftNumElt1 numEltAbs
+ signum = mliftNumElt1 numEltSignum
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
+--}
+
+instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Mixed sh a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
recip = mliftPrim recip
(/) = mliftPrim2 (/)
-instance (Floating a, PrimElt a) => Floating (Mixed sh a) where
+instance (Floating a, PrimElt a, NumConstr a) => Floating (Mixed sh a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
exp = mliftPrim exp
log = mliftPrim log
@@ -1316,7 +1346,7 @@ arithPromoteRanked2 :: forall n a. PrimElt a
-> Ranked n a -> Ranked n a -> Ranked n a
arithPromoteRanked2 = coerce
-instance (Num a, PrimElt a) => Num (Ranked n a) where
+instance (NumConstr a, PrimElt a) => Num (Ranked n a) where
(+) = arithPromoteRanked2 (+)
(-) = arithPromoteRanked2 (-)
(*) = arithPromoteRanked2 (*)
@@ -1325,12 +1355,12 @@ instance (Num a, PrimElt a) => Num (Ranked n a) where
signum = arithPromoteRanked signum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate"
-instance (Fractional a, PrimElt a) => Fractional (Ranked n a) where
+instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Ranked n a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate"
recip = arithPromoteRanked recip
(/) = arithPromoteRanked2 (/)
-instance (Floating a, PrimElt a) => Floating (Ranked n a) where
+instance (Floating a, PrimElt a, NumConstr a) => Floating (Ranked n a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate"
exp = arithPromoteRanked exp
log = arithPromoteRanked log
@@ -1616,7 +1646,7 @@ arithPromoteShaped2 :: forall sh a. PrimElt a
-> Shaped sh a -> Shaped sh a -> Shaped sh a
arithPromoteShaped2 = coerce
-instance (Num a, PrimElt a) => Num (Shaped sh a) where
+instance (NumConstr a, PrimElt a) => Num (Shaped sh a) where
(+) = arithPromoteShaped2 (+)
(-) = arithPromoteShaped2 (-)
(*) = arithPromoteShaped2 (*)
@@ -1625,12 +1655,12 @@ instance (Num a, PrimElt a) => Num (Shaped sh a) where
signum = arithPromoteShaped signum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate"
-instance (Fractional a, PrimElt a) => Fractional (Shaped sh a) where
+instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Shaped sh a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicate"
recip = arithPromoteShaped recip
(/) = arithPromoteShaped2 (/)
-instance (Floating a, PrimElt a) => Floating (Shaped sh a) where
+instance (Floating a, PrimElt a, NumConstr a) => Floating (Shaped sh a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicate"
exp = arithPromoteShaped exp
log = arithPromoteShaped log
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
new file mode 100644
index 0000000..4312cd5
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Arith.hs
@@ -0,0 +1,240 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ViewPatterns #-}
+module Data.Array.Nested.Internal.Arith where
+
+import Control.Monad (forM, guard)
+import qualified Data.Array.Internal as OI
+import qualified Data.Array.Internal.RankedG as RG
+import qualified Data.Array.Internal.RankedS as RS
+import Data.Bits
+import Data.Int
+import Data.List (sort)
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.C.Types
+import Foreign.Ptr
+import Foreign.Storable (Storable)
+import GHC.TypeLits
+import Language.Haskell.TH
+import System.IO.Unsafe
+
+import Data.Array.Nested.Internal.Arith.Foreign
+import Data.Array.Nested.Internal.Arith.Lists
+
+
+mliftVEltwise1 :: Storable a
+ => SNat n
+ -> (VS.Vector a -> VS.Vector a)
+ -> RS.Array n a -> RS.Array n a
+mliftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+ | Just prefixSz <- stridesDense sh strides =
+ let vec' = f (VS.slice offset prefixSz vec)
+ in RS.A (RG.A sh (OI.T strides 0 vec'))
+ | otherwise = RS.fromVector sh (f (RS.toVector arr))
+
+mliftVEltwise2 :: Storable a
+ => SNat n
+ -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
+ -> RS.Array n a -> RS.Array n a -> RS.Array n a
+mliftVEltwise2 SNat f
+ arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
+ arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
+ | sh1 /= sh2 = error $ "mliftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+ | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs
+ | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of
+ (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
+ let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
+ in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
+ (Just 1, Just n) -> -- scalar * dense
+ RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2)))
+ (Just n, Just 1) -> -- dense * scalar
+ RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2)))
+ (_, _) -> -- fallback case
+ RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))
+
+-- | Given the shape vector and the stride vector, return whether this vector
+-- of strides uses a dense prefix of its backing array. If so, the number of
+-- elements in this prefix is returned.
+-- This excludes any offset.
+stridesDense :: [Int] -> [Int] -> Maybe Int
+stridesDense sh _ | any (<= 0) sh = Just 0
+stridesDense sh str =
+ -- sort dimensions on their stride, ascending
+ case sort (zip str sh) of
+ [] -> Just 0
+ (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str'
+ _ -> error "Orthotope array's shape vector and stride vector have different lengths"
+ where
+ -- Given size of currently densely covered region at beginning of the
+ -- array, the remaining shape vector and the corresponding remaining stride
+ -- vector, return whether this all together covers a dense prefix of the
+ -- array. If it does, return the number of elements in this prefix.
+ checkCover :: Int -> [Int] -> [Int] -> Maybe Int
+ checkCover block [] [] = Just block
+ checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str'
+ checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths"
+
+vectorOp1 :: forall a b. Storable a
+ => (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> Ptr b -> IO ())
+ -> VS.Vector a -> VS.Vector a
+vectorOp1 ptrconv f v = unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length v)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith v $ \pv ->
+ f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv)
+ VS.unsafeFreeze outv
+
+-- | If two vectors are given, assumes that they have the same length.
+vectorOp2 :: forall a b. Storable a
+ => (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (a -> a -> a)
+ -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv
+ -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs
+ -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv
+ -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a
+vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
+ (Left x) (Left y) -> VS.singleton (fss x y)
+
+ (Left x) (Right vy) ->
+ unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length vy)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith vy $ \pvy ->
+ fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy)
+ VS.unsafeFreeze outv
+
+ (Right vx) (Left y) ->
+ unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length vx)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith vx $ \pvx ->
+ fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y)
+ VS.unsafeFreeze outv
+
+ (Right vx) (Right vy)
+ | VS.length vx == VS.length vy ->
+ unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length vx)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith vx $ \pvx ->
+ VS.unsafeWith vy $ \pvy ->
+ fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy)
+ VS.unsafeFreeze outv
+ | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
+
+flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
+ -> Int64 -> Ptr a -> Ptr a -> a -> IO ()
+flipOp f n out v s = f n out s v
+
+class NumElt a where
+ numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM binopsList $ \arithop -> do
+ let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_" ++ aboName arithop ++ "_" ++ atCName arithtype
+ c_ss = varE (aboScalFun arithop arithtype)
+ c_sv = varE $ mkName (cnamebase ++ "_sv")
+ c_vs | aboComm arithop == NonComm = varE $ mkName (cnamebase ++ "_vs")
+ | otherwise = [| flipOp $c_sv |]
+ c_vv = varE $ mkName (cnamebase ++ "_vv")
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> mliftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM unopsList $ \arithop -> do
+ let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype)
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> mliftVEltwise1 sn (vectorOp1 id $c_op) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+-- This branch is ostensibly a runtime branch, but will (hopefully) be
+-- constant-folded away by GHC.
+intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
+ => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
+ -> (SNat n -> RS.Array n i -> RS.Array n i)
+intWidBranch1 f32 f64 sn
+ | finiteBitSize (undefined :: i) == 32 = mliftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
+ | finiteBitSize (undefined :: i) == 64 = mliftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
+ | otherwise = error "Unsupported Int width"
+
+intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => (i -> i -> i) -- ss
+ -- int32
+ -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- sv
+ -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ()) -- vs
+ -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- vv
+ -- int64
+ -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv
+ -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
+intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
+ | finiteBitSize (undefined :: i) == 32 = mliftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
+ | finiteBitSize (undefined :: i) == 64 = mliftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
+ | otherwise = error "Unsupported Int width"
+
+instance NumElt Int32 where
+ numEltAdd = addVectorInt32
+ numEltSub = subVectorInt32
+ numEltMul = mulVectorInt32
+ numEltNeg = negVectorInt32
+ numEltAbs = absVectorInt32
+ numEltSignum = signumVectorInt32
+
+instance NumElt Int64 where
+ numEltAdd = addVectorInt64
+ numEltSub = subVectorInt64
+ numEltMul = mulVectorInt64
+ numEltNeg = negVectorInt64
+ numEltAbs = absVectorInt64
+ numEltSignum = signumVectorInt64
+
+instance NumElt Float where
+ numEltAdd = addVectorFloat
+ numEltSub = subVectorFloat
+ numEltMul = mulVectorFloat
+ numEltNeg = negVectorFloat
+ numEltAbs = absVectorFloat
+ numEltSignum = signumVectorFloat
+
+instance NumElt Double where
+ numEltAdd = addVectorDouble
+ numEltSub = subVectorDouble
+ numEltMul = mulVectorDouble
+ numEltNeg = negVectorDouble
+ numEltAbs = absVectorDouble
+ numEltSignum = signumVectorDouble
+
+instance NumElt Int where
+ numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
+ numEltSub = intWidBranch2 @Int (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv
+ numEltMul = intWidBranch2 @Int (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv
+ numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64
+ numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64
+ numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64
+
+instance NumElt CInt where
+ numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
+ numEltSub = intWidBranch2 @CInt (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv
+ numEltMul = intWidBranch2 @CInt (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv
+ numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64
+ numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64
+ numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
new file mode 100644
index 0000000..dbd9ddc
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -0,0 +1,33 @@
+{-# LANGUAGE ForeignFunctionInterface #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Nested.Internal.Arith.Foreign where
+
+import Control.Monad
+import Data.Int
+import Data.Maybe
+import Foreign.Ptr
+import Language.Haskell.TH
+
+import Data.Array.Nested.Internal.Arith.Lists
+
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM binopsList $ \arithop -> do
+ let base = aboName arithop ++ "_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,guard (aboComm arithop == NonComm) >>
+ Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ forM unopsList $ \arithop -> do
+ let base = auoName arithop ++ "_" ++ atCName arithtype
+ ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
new file mode 100644
index 0000000..1b29770
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs
@@ -0,0 +1,47 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Data.Array.Nested.Internal.Arith.Lists where
+
+import Data.Int
+
+import Language.Haskell.TH
+
+
+data Commutative = Comm | NonComm
+ deriving (Show, Eq)
+
+data ArithType = ArithType
+ { atType :: Name -- ''Int32
+ , atCName :: String -- "i32"
+ }
+
+typesList :: [ArithType]
+typesList =
+ [ArithType ''Int32 "i32"
+ ,ArithType ''Int64 "i64"
+ ,ArithType ''Float "float"
+ ,ArithType ''Double "double"
+ ]
+
+data ArithBOp = ArithBOp
+ { aboName :: String -- "add"
+ , aboComm :: Commutative -- Comm
+ , aboScalFun :: ArithType -> Name -- \_ -> '(+)
+ }
+
+binopsList :: [ArithBOp]
+binopsList =
+ [ArithBOp "add" Comm (\_ -> '(+))
+ ,ArithBOp "sub" NonComm (\_ -> '(-))
+ ,ArithBOp "mul" Comm (\_ -> '(*))
+ ]
+
+data ArithUOp = ArithUOp
+ { auoName :: String -- "neg"
+ }
+
+unopsList :: [ArithUOp]
+unopsList =
+ [ArithUOp "neg"
+ ,ArithUOp "abs"
+ ,ArithUOp "signum"
+ ]