summaryrefslogtreecommitdiff
path: root/SC/Defs.hs
blob: fac4e33aa20481423f65bcce8f55b7847cc6244c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
module SC.Defs where

import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Type

import qualified Language.C as C
import Language.C (Name(..))
import SC.Monad


-- ENVIRONMENTS
-- ------------

-- Array variable environment, indexed by the environment type 'env'.
-- Each entry records the shape names and the array names belonging to one
-- 'Array sh t' in the environment.
data AVarEnv env where
    -- The empty environment.
    AVENil :: AVarEnv ()
    -- Extend an environment with the names for one more array.
    AVEPush :: ShNames sh -> ANames t -> AVarEnv env -> AVarEnv (env, Array sh t)

-- Project the shape names and array names of the array that the given index
-- refers to. ZeroIdx selects the most recently pushed entry; every SuccIdx
-- skips one entry and continues in the remaining environment.
aveprj :: AVarEnv env -> Idx env (Array sh t) -> (ShNames sh, ANames t)
aveprj (AVEPush shapeNames arrNames rest) idx =
    case idx of
        ZeroIdx -> (shapeNames, arrNames)
        SuccIdx idx' -> aveprj rest idx'

-- Variable environment, indexed by the environment type 'env'.
-- Each entry stores only a Name; the entry's type 't' appears in the
-- environment index but is not stored alongside the name.
data VarEnv env where
    -- The empty environment.
    VENil :: VarEnv ()
    -- Extend an environment with the name of one more variable.
    VEPush :: Name -> VarEnv env -> VarEnv (env, t)

-- Project the Name stored for the variable that the given index refers to.
-- ZeroIdx selects the most recently pushed name; every SuccIdx skips one
-- entry and continues in the remaining environment.
veprj :: VarEnv env -> Idx env t -> Name
veprj (VEPush name rest) idx =
    case idx of
        ZeroIdx -> name
        SuccIdx idx' -> veprj rest idx'


-- IGNORE TUPLES
-- -------------

-- A tuple-shaped tree of values of type 's', indexed by a type 't'.
-- Parts of the tuple may be left out ("ignored") entirely.
data ITup s t where
    -- A pair; the index type is the pair of the component index types.
    ITupPair :: ITup s a -> ITup s b -> ITup s (a, b)
    -- A single value, usable at any index type.
    ITupSingle :: s -> ITup s a
    -- No value at all, usable at any index type.
    ITupIgnore :: ITup s a

-- Fold over an ITup, replacing each constructor by the matching handler:
--   * the first argument stands in for ITupIgnore,
--   * the second is applied to the value held by ITupSingle,
--   * the third combines the folded components of an ITupPair.
-- The result is indexed by the same type 't' as the input tuple.
itupfold :: (forall a. f a) -> (forall a. s -> f a) -> (forall a b. f a -> f b -> f (a, b))
         -> ITup s t -> f t
itupfold onIgnore onSingle onPair tup =
    case tup of
        ITupPair l r ->
            onPair (itupfold onIgnore onSingle onPair l)
                   (itupfold onIgnore onSingle onPair r)
        ITupSingle x -> onSingle x
        ITupIgnore -> onIgnore

-- Apply a function to every value stored in an ITup, keeping the tuple
-- structure (pairs and ignored positions) and the index type unchanged.
itupmap :: (s1 -> s2) -> ITup s1 t -> ITup s2 t
itupmap f (ITupPair l r) = ITupPair (itupmap f l) (itupmap f r)
itupmap f (ITupSingle x) = ITupSingle (f x)
itupmap _ ITupIgnore = ITupIgnore

-- Flatten an ITup into the list of values it holds, in left-to-right order.
-- Ignored positions contribute nothing.
itupList :: ITup s t -> [s]
itupList tup = collect tup []
  where
    -- Prepend the values of the given tuple to the accumulated tail.
    collect :: ITup a u -> [a] -> [a]
    collect ITupIgnore acc = acc
    collect (ITupSingle x) acc = x : acc
    collect (ITupPair l r) acc = collect l (collect r acc)

-- A Name together with its C.Type.
data TypedName = TypedName C.Type Name
-- A tuple-shaped collection of typed names.
type Names = ITup TypedName
-- A tuple-shaped collection of typed array names (see TypedAName).
type ANames = ITup TypedAName

-- A tuple-shaped collection of C expressions.
type Exprs = ITup C.Expr

-- Turn every typed name in the tuple into a variable expression (C.EVar)
-- referring to that name; the stored C type is dropped.
itupEvars :: ITup TypedName t -> Exprs t
itupEvars = itupmap toVarExpr
  where
    toVarExpr (TypedName _ name) = C.EVar name

-- A Name for an array together with a C.Type. The stored type is the
-- *pointer* type to the type that, according to the type index, this name is
-- supposed to have.
data TypedAName = TypedAName C.Type Name

-- Names for a tuple of arrays, indexed by the type 't' of that tuple.
data TupANames t where
    -- A pair of array-name tuples.
    ANPair :: TupANames a -> TupANames b -> TupANames (a, b)
    -- Names for a single array: its shape names and its typed array names.
    ANArray :: ShNames sh -> ITup TypedAName t -> TupANames (Array sh t)
    -- No names at all, usable at any index type.
    ANIgnore :: TupANames a

-- Flatten a TupANames into two lists: first all shape names, then all data
-- array names. For a pair, the names of the left component come before
-- those of the right component in both lists.
tupanamesList :: TupANames t -> ([TypedName], [TypedAName])
tupanamesList ANIgnore = ([], [])
tupanamesList (ANArray shapeNames dataNames) =
    (shnamesList shapeNames, itupList dataNames)
tupanamesList (ANPair left right) =
    (leftShapes ++ rightShapes, leftArrays ++ rightArrays)
  where
    (leftShapes, leftArrays) = tupanamesList left
    (rightShapes, rightArrays) = tupanamesList right

-- Names for the dimensions of a shape, indexed by the shape type 'sh'.
data ShNames sh where
    -- The shape with no dimensions.
    ShZ :: ShNames ()
    -- Add the name for one more Int dimension to an existing shape.
    ShS :: Name -> ShNames sh -> ShNames (sh, Int)

-- List the dimension names of a shape as typed names; every dimension is
-- given the type C.TInt C.B64.
-- The name held by the outermost ShS comes first. Note that this is the
-- reverse of the order produced by `itupList . fromShNames`.
shnamesList :: ShNames sh -> [TypedName]
shnamesList names =
    case names of
        ShZ -> []
        ShS n rest -> TypedName (C.TInt C.B64) n : shnamesList rest

-- Convert a tuple of typed names, structured like the shape type 'sh', into
-- ShNames. The C types stored in the tuple are discarded.
--
-- The tuple must mirror the shape exactly: ITupIgnore for the empty shape,
-- and for each further dimension an ITupPair of the remaining shape's names
-- and an ITupSingle holding that dimension's name (this is the form that
-- fromShNames produces). Any other structure is a caller error and aborts
-- with 'error'.
makeShNames :: ShapeR sh -> ITup TypedName sh -> ShNames sh
makeShNames ShapeRz ITupIgnore = ShZ
makeShNames (ShapeRsnoc sht) (ITupPair ns (ITupSingle (TypedName _ n))) =
    ShS n (makeShNames sht ns)
makeShNames _ _ =
    error "SC.Defs.makeShNames: name tuple does not match the structure of the shape"

-- Convert ShNames into the equivalent tuple of typed names. Each dimension
-- name is given the type C.TInt C.B64; the empty shape becomes ITupIgnore.
fromShNames :: ShNames sh -> ITup TypedName sh
fromShNames names =
    case names of
        ShZ -> ITupIgnore
        ShS n rest ->
            let dimName = TypedName (C.TInt C.B64) n
            in ITupPair (fromShNames rest) (ITupSingle dimName)

-- Recover the ShapeR witness for the shape type from its ShNames: one
-- ShapeRsnoc per ShS, ending in ShapeRz for ShZ. The names are not used.
shNamesShape :: ShNames sh -> ShapeR sh
shNamesShape names =
    case names of
        ShS _ rest -> ShapeRsnoc (shNamesShape rest)
        ShZ -> ShapeRz


-- GENERATING VARIABLE NAMES
-- -------------------------

-- Generate a Name consisting of the given prefix followed by the 'show'
-- rendering of an identifier obtained from genId.
genName :: String -> SC Name
genName prefix = fmap withPrefix genId
  where
    withPrefix ident = Name (prefix ++ show ident)


-- TYPE CONVERSION
-- ---------------

-- Translate an Accelerate scalar type into the corresponding C type.
-- Integral types map to C.TInt (signed) or C.TUInt (unsigned) of the
-- matching bit width, with Int and Word treated as 64-bit. Float and Double
-- map to C.TFloat and C.TDouble. Half-precision floats and vector types are
-- rejected through 'throw'.
cvtType :: ScalarType t -> SC C.Type
cvtType scalarTy =
    case scalarTy of
        VectorScalarType{} -> throw "Vector types not supported"
        SingleScalarType (NumSingleType numTy) ->
            case numTy of
                IntegralNumType intTy -> return (integralType intTy)
                FloatingNumType floatTy -> floatingType floatTy
  where
    -- Integral types always have a C counterpart.
    integralType :: IntegralType a -> C.Type
    integralType ity =
        case ity of
            -- Signed integers.
            TypeInt -> C.TInt C.B64
            TypeInt8 -> C.TInt C.B8
            TypeInt16 -> C.TInt C.B16
            TypeInt32 -> C.TInt C.B32
            TypeInt64 -> C.TInt C.B64
            -- Unsigned integers.
            TypeWord -> C.TUInt C.B64
            TypeWord8 -> C.TUInt C.B8
            TypeWord16 -> C.TUInt C.B16
            TypeWord32 -> C.TUInt C.B32
            TypeWord64 -> C.TUInt C.B64

    -- Floating-point types; half precision is rejected.
    floatingType :: FloatingType a -> SC C.Type
    floatingType fty =
        case fty of
            TypeHalf -> throw "Half floats not supported"
            TypeFloat -> return C.TFloat
            TypeDouble -> return C.TDouble