summaryrefslogtreecommitdiff
path: root/SC/Exp.hs
blob: d033cc808a2a1cfbc6749986ac1587188184210d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module SC.Exp where

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type

import qualified Language.C as C
import SC.Defs
import SC.Monad


-- | A scalar function compiled to C: the generated procedure definition,
-- together with a builder for the argument list of a call to that procedure.
--
-- NOTE(review): @aenv@ does not occur in either field; presumably it records
-- the array environment whose array names the argument builder refers to —
-- confirm.
data CompiledFun aenv t1 t2 =
    CompiledFun
        C.FunDef  -- ^ expression function implementation
        (Exprs t1 -> Names t2 -> [C.Expr])
            -- ^ arguments builder. Given:
            --   - expressions that compute the direct arguments;
            --   - names that the output values should be stored in;
            --   returns the list of arguments to be passed to the compiled
            --   function. The outputs will be stored by storing to pointers to
            --   the given names.
            --   The arguments will refer to array variable names found in the
            --   original array environment.

-- | Compile a single-argument scalar function into a C procedure.
--
-- The function must be single-argument; uncurry it first if necessary (e.g.
-- for zipWith).
--
-- The generated procedure's parameters are, in order:
--
--   * the arrays the body reads (as reported by 'compileExp''),
--   * one parameter per component of the argument,
--   * one pointer parameter per component of the result.
--
-- The procedure body evaluates the expression and writes every result
-- component through its pointer as @ptr[0] = value@. The returned argument
-- builder lays out call arguments in the same order. For the outputs it
-- passes the address of each given destination name.
compileFun :: AVarEnv aenv -> A.Fun aenv (t1 -> t2) -> SC (CompiledFun aenv t1 t2)
compileFun aenv (A.Lam lhs (A.Body body)) = do
    let resultTy = A.expType body
    procName <- genName "expfun_"
    (paramNames, varEnv) <- genVarsEnv lhs VENil
    -- Output parameters are pointers to the result component types.
    outPtrs <- itupmap (\(TypedName ty nm) -> TypedName (C.TPtr ty) nm)
                   <$> genVars resultTy
    (arrays, compiled) <- compileExp' aenv varEnv body
    (bodySts, results) <- toStExprs resultTy compiled
    let params = concat
            [ [(ty, nm) | TypedAName ty nm <- arrays]
            , [(ty, nm) | TypedName ty nm <- itupList paramNames]
            , [(ty, nm) | TypedName ty nm <- itupList outPtrs]
            ]
        procDef = C.ProcDef procName params (bodySts ++ storeOutputs outPtrs results)
        buildArgs argExprs dests = concat
            [ [C.EVar nm | TypedAName _ nm <- arrays]
            , itupList argExprs
            , [C.EPtrTo (C.EVar nm) | TypedName _ nm <- itupList dests]
            ]
    return (CompiledFun procDef buildArgs)
  where
    -- Write each result expression through its matching output pointer.
    storeOutputs :: Names t -> Exprs t -> [C.Stmt]
    storeOutputs ITupIgnore _ = []
    storeOutputs (ITupSingle (TypedName _ nm)) (ITupSingle ex) = [C.SStore nm (C.ELit "0") ex]
    storeOutputs (ITupPair ns1 ns2) (ITupPair es1 es2) = storeOutputs ns1 es1 ++ storeOutputs ns2 es2
    storeOutputs _ _ = error "wat"
compileFun _ _ = error "compileFun: Not single-argument function"

-- | Compile a closed scalar expression as a compiled function of the unit
-- type: the expression becomes the body of a lambda whose (unit) argument is
-- bound by a wildcard, and that lambda is handed to 'compileFun'.
compileExp :: AVarEnv aenv -> A.Exp aenv t -> SC (CompiledFun aenv () t)
compileExp aenv = compileFun aenv . A.Lam (LeftHandSideWildcard TupRunit) . A.Body

-- | Compile a scalar expression in the given array and scalar-variable
-- environments.
--
-- Returns the array names the generated code reads (in order of discovery,
-- possibly with repeats), together with the compiled result in one of two
-- forms:
--
--   * @Right (sts, exprs)@: execute @sts@; afterwards @exprs@ are C
--     expressions for the result components.
--   * @Left store@: given destination names for the result components,
--     @store@ yields statements that write the result into them.
--
-- Constructors without a case below are rejected via 'throw'.
compileExp' :: AVarEnv aenv -> VarEnv env -> A.OpenExp env aenv t
            -> SC ([TypedAName], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
compileExp' aenv env = \case
    A.Let lhs rhs body -> do
        (names, env') <- genVarsEnv lhs env
        -- Declare the let-bound variables, then store the rhs into them.
        let sts1 = [C.SDecl t n Nothing | TypedName t n <- itupList names]
        (usedA2, sts2) <- fmap (`toStoring` names) <$> compileExp' aenv env rhs
        (usedA3, res3) <- compileExp' aenv env' body
        -- The declarations and rhs stores must precede the body's code in BOTH
        -- result forms. 'fmap' on 'Either' only touches 'Right', so using it
        -- here would drop them whenever the body compiles to a storing
        -- function (e.g. a tuple built with 'A.Pair').
        let prefix = sts1 ++ sts2
        return (usedA2 ++ usedA3, case res3 of
            Left storeBody -> Left (\dests -> prefix ++ storeBody dests)
            Right (sts, exs) -> Right (prefix ++ sts, exs))

    A.Evar (Var _ idx) ->
        return ([], Right ([], ITupSingle (C.EVar (veprj env idx))))

    -- A tuple becomes a storing function that stores each component into the
    -- matching half of the destination names.
    A.Pair a b -> do
        (usedA1, res1) <- compileExp' aenv env a
        (usedA2, res2) <- compileExp' aenv env b
        return (usedA1 ++ usedA2, Left (\case
            ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2
            ITupIgnore -> []
            ITupSingle _ -> error "wat"))

    -- Arithmetic primitives map onto the corresponding infix C operators.
    A.PrimApp (A.PrimAdd _) e -> binary aenv env "+" e
    A.PrimApp (A.PrimSub _) e -> binary aenv env "-" e
    A.PrimApp (A.PrimMul _) e -> binary aenv env "*" e
    A.PrimApp (A.PrimQuot _) e -> binary aenv env "/" e
    A.PrimApp (A.PrimRem _) e -> binary aenv env "%" e

    -- The shape is built from the names recorded for this array in the array
    -- environment.
    -- NOTE(review): these names are not added to the used-array list, so they
    -- do not become parameters of the generated procedure; confirm they are
    -- in scope there.
    A.Shape (Var _ idx) ->
        let (shnames, _) = aveprj aenv idx
            buildExprs :: ShNames sh -> Exprs sh
            buildExprs ShZ = ITupIgnore
            buildExprs (ShS n names) = ITupPair (buildExprs names) (ITupSingle (C.EVar n))
        in return ([], Right ([], buildExprs shnames))

    -- Linearise an index into a shape:
    -- toIndex (sh, n) (ix, i) = toIndex sh ix * n + i, and rank 0 gives 0.
    A.ToIndex shr she idxe -> do
        let build :: ShapeR sh -> Exprs sh -> Exprs sh -> C.Expr
            build ShapeRz _ _ = C.ELit "0"
            build (ShapeRsnoc ShapeRz) _ (ITupPair _ (ITupSingle idxe')) = idxe'
            build (ShapeRsnoc shr') (ITupPair shes' (ITupSingle she'))
                                    (ITupPair idxes' (ITupSingle idxe')) =
                C.EOp (C.EOp (build shr' shes' idxes') "*" she') "+" idxe'
            build _ _ _ = error "wat"
        (usedA1, res1) <- compileExp' aenv env she
        (sts1, shes) <- toStExprs (shapeType shr) res1
        (usedA2, res2) <- compileExp' aenv env idxe
        (sts2, idxes) <- toStExprs (shapeType shr) res2
        return (usedA1 ++ usedA2, Right (sts1 ++ sts2, ITupSingle (build shr shes idxes)))

    -- Multidimensional indexing is rewritten to linear indexing at the
    -- linearised index of the array's own shape.
    A.Index avar@(Var (ArrayR shr _) _) she ->
        compileExp' aenv env $
            A.LinearIndex avar (A.ToIndex shr (A.Shape avar) she)

    -- Store the linear index in a fresh @TInt B64@ temporary, then index every
    -- component array of the array variable with it.
    A.LinearIndex (Var _ idx) e -> do
        temp <- genName "i"
        let sts0 = [C.SDecl (C.TInt C.B64) temp Nothing]
        (usedA1, sts1) <- fmap (`toStoring` ITupSingle (TypedName (C.TInt C.B64) temp))
                              <$> compileExp' aenv env e
        let (_, anames) = aveprj aenv idx
            usedA = itupList anames ++ usedA1
        return (usedA, Right (sts0 ++ sts1
                             ,itupmap (\(TypedAName _ name) -> C.EIndex name (C.EVar temp)) anames))

    _ -> throw "Unsupported Exp constructor"
  where
    -- Compile the (pair-typed) operand of a binary primitive and combine its
    -- two components with the infix C operator @op@. An operand that does not
    -- yield a pair of single expressions fails the pattern bind (the monad's
    -- 'fail').
    binary :: AVarEnv aenv -> VarEnv env -> String -> A.OpenExp env aenv (a, b)
           -> SC ([TypedAName], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
    binary aenv' env' op e' = do
        (usedA, res) <- compileExp' aenv' env' e'
        (sts, ITupPair (ITupSingle e1) (ITupSingle e2)) <-
            toStExprs (A.expType e') res
        return (usedA, Right (sts, ITupSingle (C.EOp e1 op e2)))

-- | Normalise a compiled result to the "statements + expressions" form.
--
-- A result already in that form is returned unchanged. A storing function is
-- handed freshly generated variables of type @ty@. The returned statements
-- first declare those variables (uninitialised), then run the stores. The
-- returned expressions refer to the variables.
toStExprs :: TypeR t -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t)
toStExprs ty (Left fun) = do
    names <- genVars ty
    -- The fresh variables must be declared before they are assigned. The
    -- 'A.Let' and 'A.LinearIndex' cases do the same before storing into new
    -- names. Without this, the stores would target undeclared C variables.
    let decls = [C.SDecl t n Nothing | TypedName t n <- itupList names]
        sts1 = fun names
    return (decls ++ sts1, itupmap (\(TypedName _ n) -> C.EVar n) names)
toStExprs _ (Right pair) = return pair

-- | Turn a compiled result into statements that store it into the given
-- destination names.
--
-- A storing function is simply applied to the destinations. Otherwise the
-- result's statements come first, followed by one assignment per result
-- component. Ignored destinations receive no assignment, and mismatched
-- shapes are an internal error.
toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt]
toStoring res dests = case res of
    Left storeInto -> storeInto dests
    Right (prelude, exprs) -> prelude ++ assign dests exprs
  where
    assign :: Names t -> Exprs t -> [C.Stmt]
    assign ITupIgnore _ = []
    assign (ITupSingle (TypedName _ var)) (ITupSingle ex) = [C.SAsg var ex]
    assign (ITupPair lhs1 lhs2) (ITupPair rhs1 rhs2) = assign lhs1 rhs1 ++ assign lhs2 rhs2
    assign _ _ = error "wat"

-- | Allocate fresh C variable names (prefix @x@) for every scalar bound by a
-- left-hand side, pushing each one onto the variable environment.
--
-- Wildcards allocate nothing and leave the environment unchanged. For pairs,
-- the left component is processed first, and the environment it produces is
-- threaded into the right component.
genVarsEnv :: A.ELeftHandSide t env env' -> VarEnv env -> SC (Names t, VarEnv env')
genVarsEnv lhs env = case lhs of
    LeftHandSideWildcard _ ->
        return (ITupIgnore, env)
    LeftHandSideSingle sty -> do
        fresh <- genName "x"
        cty <- cvtType sty
        return (ITupSingle (TypedName cty fresh), VEPush fresh env)
    LeftHandSidePair left right -> do
        (leftNames, midEnv) <- genVarsEnv left env
        (rightNames, outEnv) <- genVarsEnv right midEnv
        return (ITupPair leftNames rightNames, outEnv)

-- | Allocate fresh, typed C variable names for every scalar component of a
-- type, left to right. Unit components get no variable.
genVars :: TypeR t -> SC (Names t)
genVars = \case
    TupRunit -> return ITupIgnore
    TupRsingle sty -> genVar sty
    TupRpair lty rty -> do
        lnames <- genVars lty
        rnames <- genVars rty
        return (ITupPair lnames rnames)

-- | Allocate one fresh C variable (prefix @x@) for a scalar type. The C type
-- is converted before the name is generated.
genVar :: ScalarType t -> SC (Names t)
genVar ty = do
    cty <- cvtType ty
    fresh <- genName "x"
    return (ITupSingle (TypedName cty fresh))