aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Arith/Lists.hs
blob: 78fe24ad873eaa3660044778d9fb1304c05b9f03 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
{-# LANGUAGE TemplateHaskellQuotes #-}
module Data.Array.Nested.Internal.Arith.Lists where

import Data.Int

import Language.Haskell.TH


-- | Whether the operands of a binary operation may be swapped without
-- changing its result.
data Commutative
  = Comm     -- ^ operand order does not matter
  | NonComm  -- ^ operand order matters
  deriving (Eq, Show)

-- | One element type over which the arithmetic operation tables range.
data ArithType = ArithType
  { atType  :: Name    -- ^ Template Haskell name of the Haskell type, e.g. @''Int32@
  , atCName :: String  -- ^ name used for this type on the C side, e.g. @"i32"@
  }

-- | All element types, each paired with its C-side name.
typesList :: [ArithType]
typesList = map (uncurry ArithType)
  [ (''Int32,  "i32")
  , (''Int64,  "i64")
  , (''Float,  "float")
  , (''Double, "double")
  ]

-- | Description of a binary arithmetic operation.
data ArithBOp = ArithBOp
  { aboName    :: String             -- ^ operation name, e.g. @"add"@
  , aboComm    :: Commutative        -- ^ whether the operands may be swapped
  , aboScalFun :: ArithType -> Name  -- ^ TH name of the scalar function for a given
                                     --   element type, e.g. @const '(+)@
  }

-- | All binary operations. Each uses the same Prelude operator for every
-- element type, so the scalar-function selector ignores its argument.
binopsList :: [ArithBOp]
binopsList =
  [ ArithBOp { aboName = "add", aboComm = Comm,    aboScalFun = const '(+) }
  , ArithBOp { aboName = "sub", aboComm = NonComm, aboScalFun = const '(-) }
  , ArithBOp { aboName = "mul", aboComm = Comm,    aboScalFun = const '(*) }
  ]

-- | Description of a unary arithmetic operation.
--
-- A 'newtype' rather than 'data': it wraps exactly one field, so this form
-- has no runtime wrapper (HLint: "Use newtype instead of data").
newtype ArithUOp = ArithUOp
  { auoName :: String  -- ^ operation name, e.g. @"neg"@
  }

-- | All unary operations, by name.
unopsList :: [ArithUOp]
unopsList = map ArithUOp ["neg", "abs", "signum"]

-- | Description of a reduction operation.
--
-- A 'newtype' rather than 'data': it wraps exactly one field, so this form
-- has no runtime wrapper (HLint: "Use newtype instead of data").
newtype ArithRedOp = ArithRedOp
  { aroName :: String  -- ^ reduction name, e.g. @"sum1"@
  }

-- | All reduction operations, by name.
redopsList :: [ArithRedOp]
redopsList = map ArithRedOp ["sum1", "product1"]