summaryrefslogtreecommitdiff
path: root/SC/Exp.hs
blob: e24786c68a705c3048e5e8d09164aa828c8bce9a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module SC.Exp where

import Data.List (nubBy)
import Debug.Trace

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Idx (idxToInt)
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type

import Debug
import qualified Language.C as C
import SC.Defs
import SC.Monad


-- | A scalar function compiled to a standalone C procedure, together with
-- what a caller needs to invoke it.
data CompiledFun aenv t1 t2 where
    CompiledFun
        :: C.FunDef
            -- ^ the C procedure implementing the expression function
        -> (Exprs t1 -> Names t2 -> [C.Expr])
            -- ^ call-site argument builder. It takes the C expressions that
            --   compute the direct arguments and the names that should
            --   receive the outputs, and yields the complete argument list
            --   for a call to the procedure above. Outputs are written
            --   through pointers to the given names. Array arguments refer to
            --   variable names from the original array environment.
        -> [SomeArray]
            -- ^ environment arrays that the built arguments refer to
        -> CompiledFun aenv t1 t2

-- | All C variable names belonging to one source-level array (as it was before
-- SoA conversion): the names of its shape components and of its data arrays.
-- The shape and element types are hidden.
data SomeArray where
    SomeArray :: ShNames sh -> ANames t -> SomeArray

-- | Compile a single-argument scalar function to a static C procedure.
-- Uncurry if necessary (e.g. for zipWith).
--
-- The generated procedure's parameters are, in order:
--
--   * the shape and data variables of the environment arrays the body reads;
--   * the function's argument variables;
--   * one pointer per result component.
--
-- Results are returned by storing through those pointers. A debug rendering
-- of the compiled expression tree is emitted with 'traceM'.
--
-- Calls 'error' if the function is not of the form @Lam lhs (Body e)@.
compileFun :: AVarEnv aenv -> A.Fun aenv (t1 -> t2) -> SC (CompiledFun aenv t1 t2)
compileFun aenv (A.Lam lhs (A.Body body)) = do
    funname <- genName "expfun_"
    (argnames, env) <- genVarsEnv lhs VENil
    -- Result variables become pointer parameters; the procedure writes its
    -- results through them.
    outnames <- itupmap (\(TypedName t n) -> TypedName (C.TPtr t) n)
                    <$> genVars (A.expType body)
    ((_tree, usedA), res) <- compileExp' aenv env body
    traceM ("Compiled expression:\n" ++ prettyTree "  " "  " _tree)
    (sts1, retexprs) <- toStExprs (A.expType body) res
    let sts2 = genoutstores outnames retexprs
        -- 'usedA' lists an array once per read, so an array read more than
        -- once would otherwise produce duplicate parameter names, which C
        -- rejects. Keep only the first occurrence of each variable name. The
        -- procedure definition and the call-site builder below both use this
        -- same list, so parameters and arguments stay aligned.
        arrayarguments =
            nubBy (\(_, n1) (_, n2) -> n1 == n2) $
                concatMap (\(SomeArray shn ans) ->
                              map (\(TypedName t n) -> (t, n)) (shnamesList shn)
                                ++ map (\(TypedAName t n) -> (t, n)) (itupList ans))
                          usedA
        arguments =
            arrayarguments
            ++ map (\(TypedName t n) -> (t, n)) (itupList argnames)
            ++ map (\(TypedName t n) -> (t, n)) (itupList outnames)
    return $ CompiledFun
        (C.ProcDef C.defAttrs { C.faStatic = True } funname arguments (sts1 ++ sts2))
        (\argexprs destnames ->
            map (C.EVar . snd) arrayarguments
            ++ itupList argexprs
            ++ map (\(TypedName _ n) -> C.EPtrTo (C.EVar n)) (itupList destnames))
        usedA
  where
    -- Emit one 'C.SStore' (at index 0) per result component, writing each
    -- result expression through its output pointer.
    genoutstores :: Names t -> Exprs t -> [C.Stmt]
    genoutstores ITupIgnore _ = []
    genoutstores (ITupSingle (TypedName _ n)) (ITupSingle e) = [C.SStore n (C.ELit "0") e]
    genoutstores (ITupPair n1 n2) (ITupPair e1 e2) = genoutstores n1 e1 ++ genoutstores n2 e2
    genoutstores _ _ =
        error "compileFun: output names and result expressions have different tuple structure"
compileFun _ _ = error "compileFun: Not single-argument function"

-- | Compile a closed scalar expression by treating it as the body of a
-- function whose (unit) argument is ignored, then delegating to 'compileFun'.
compileExp :: AVarEnv aenv -> A.Exp aenv t -> SC (CompiledFun aenv () t)
compileExp aenv = compileFun aenv . A.Lam (LeftHandSideWildcard TupRunit) . A.Body

-- | A rose tree of labels, used for debug renderings of compiled expressions.
data Tree = Leaf String | Node String [Tree]

-- | Render a 'Tree' with one label per line, using box-drawing connectors.
--
-- The first argument prefixes the root's own line. The second argument is the
-- prefix onto which each child's connector ("├─" or "└─") is appended.
prettyTree :: String -> String -> Tree -> String
prettyTree firstPrefix restPrefix tree =
    case tree of
        Leaf label -> headerLine label
        Node label children -> headerLine label ++ renderChildren children
  where
    headerLine label = firstPrefix ++ label ++ "\n"

    -- Every child except the last hangs off a tee and keeps a vertical rule
    -- for its own descendants; the last child gets a corner and blank space.
    renderChildren [] = ""
    renderChildren [final] =
        prettyTree (restPrefix ++ "└─") (restPrefix ++ "  ") final
    renderChildren (child : others) =
        prettyTree (restPrefix ++ "├─") (restPrefix ++ "│ ") child
            ++ renderChildren others

-- | Compile an open scalar expression to C.
--
-- Alongside the code, this returns a debug 'Tree' of the expression and the
-- environment arrays whose names the generated code reads. That list is built
-- by concatenation, so an array read more than once appears more than once.
--
-- The code is either @Right (stmts, exprs)@ or @Left store@:
--
--   * @Right (stmts, exprs)@: execute @stmts@, after which @exprs@ evaluate
--     to the result.
--   * @Left store@: @store names@ yields statements that assign the result
--     to the given variables.
--
-- Not yet supported: Foreign, IndexSlice, IndexFull, FromIndex, Case, Cond,
-- While, PrimConst, ShapeSize, Undef, Coerce. Unsupported constructors, and
-- constants 'showExpConst' cannot render, are reported via 'throw'.
compileExp' :: AVarEnv aenv -> VarEnv env -> A.OpenExp env aenv t
            -> SC ((Tree, [SomeArray]), Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
compileExp' aenv env = \case
    A.Let lhs rhs body -> do
        (names, env') <- genVarsEnv lhs env
        let sts1 = [C.SDecl t n Nothing | TypedName t n <- itupList names]
        ((tree2, usedA2), sts2) <- fmap (`toStoring` names) <$> compileExp' aenv env rhs
        ((tree3, usedA3), res3) <- compileExp' aenv env' body
        -- The declarations and right-hand-side statements must precede the
        -- body in BOTH result forms. Mapping only over 'Right' would silently
        -- drop them whenever the body yields a 'Left' (store) result.
        let prefix = sts1 ++ sts2
            res = case res3 of
                Left store -> Left (\outs -> prefix ++ store outs)
                Right (sts, exs) -> Right (prefix ++ sts, exs)
        return ((Node ("Let [" ++ show (length (itupList names)) ++ " vars]") [tree2, tree3], usedA2 ++ usedA3)
               ,res)

    A.Evar (Var _ idx) ->
        return ((Leaf ("Evar " ++ show (idxToInt idx)), []), Right ([], ITupSingle (C.EVar (veprj env idx))))

    A.Nil ->
        return ((Leaf "Nil", []), Right ([], ITupIgnore))

    A.Pair a b -> do
        ((tree1, usedA1), res1) <- compileExp' aenv env a
        ((tree2, usedA2), res2) <- compileExp' aenv env b
        return ((Node "Pair" [tree1, tree2], usedA1 ++ usedA2)
               ,case (res1, res2) of
                  (Right (sts1, exp1), Right (sts2, exp2)) ->
                      Right (sts1 ++ sts2, ITupPair exp1 exp2)
                  -- If either side only knows how to store itself, the pair
                  -- must be produced in store form too.
                  _ -> Left (\case
                                ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2
                                ITupIgnore -> []
                                ITupSingle _ ->
                                    error "compileExp': Pair result cannot be stored into a single variable"))

    A.While (A.Lam condlhs (A.Body condexp)) (A.Lam bodylhs (A.Body bodyexp)) initexp -> do
        -- Loop code generation is not implemented yet. The loop-state
        -- variables are allocated and each sub-expression is compiled in the
        -- environment it is typed in: the condition under 'condlhs', the body
        -- under 'bodylhs', the initial value under the outer environment.
        -- Compilation then stops with an explicit error.
        names <- genVars (lhsToTupR condlhs)
        let condenv = pushVarsLHS condlhs names env
            bodyenv = pushVarsLHS bodylhs names env
        _ <- compileExp' aenv condenv condexp
        _ <- compileExp' aenv bodyenv bodyexp
        _ <- compileExp' aenv env initexp
        throw "Unsupported Exp constructor: While (code generation not implemented)"

    -- Constants that 'showExpConst' cannot render fall through to the
    -- unsupported-constructor error at the end.
    A.Const ty x
      | Just str <- showExpConst ty x
      -> return ((Leaf ("Const (" ++ str ++ ")"), []), Right ([], ITupSingle (C.ELit str)))

    A.PrimApp (A.PrimAdd _) e -> binary aenv env "+" e
    A.PrimApp (A.PrimSub _) e -> binary aenv env "-" e
    A.PrimApp (A.PrimMul _) e -> binary aenv env "*" e
    A.PrimApp (A.PrimQuot _) e -> binary aenv env "/" e
    A.PrimApp (A.PrimRem _) e -> binary aenv env "%" e
    A.PrimApp (A.PrimFDiv _) e -> binary aenv env "/" e
    A.PrimApp (A.PrimLog TypeFloat) e -> unary aenv env "log" (C.ECall (C.Name "logf") . pure) e
    A.PrimApp (A.PrimLog TypeDouble) e -> unary aenv env "log" (C.ECall (C.Name "log") . pure) e
    A.PrimApp (A.PrimToFloating _ TypeFloat) e -> unary aenv env "cast float" (C.ECast C.TFloat) e
    A.PrimApp (A.PrimToFloating _ TypeDouble) e -> unary aenv env "cast double" (C.ECast C.TDouble) e
    A.PrimApp op _ -> throw $ "Unsupported Exp primitive operator: " ++ showPrimFun op

    -- NOTE(review): the shape names are referenced here, but the array is not
    -- added to the used-array list. A Shape whose array is not also indexed
    -- elsewhere in the expression would therefore reference names that are
    -- not passed to the generated procedure. Confirm before relying on a
    -- standalone Shape.
    A.Shape (Var _ idx) ->
        let (shnames, _) = aveprj aenv idx
            buildExprs :: ShNames sh -> Exprs sh
            buildExprs ShZ = ITupIgnore
            buildExprs (ShS names n) = ITupPair (buildExprs names) (ITupSingle (C.EVar n))
        in return ((Leaf ("Shape a" ++ show (idxToInt idx)), []), Right ([], buildExprs shnames))

    -- Row-major linearisation: toIndex (sh, n) (ix, i) = toIndex sh ix * n + i.
    A.ToIndex shr she idxe -> do
        let build :: ShapeR sh -> Exprs sh -> Exprs sh -> C.Expr
            build ShapeRz _ _ = C.ELit "0"
            -- One-dimensional special case: avoid emitting "0 * n + i".
            build (ShapeRsnoc ShapeRz) _ (ITupPair _ (ITupSingle idxe')) = idxe'
            build (ShapeRsnoc shr') (ITupPair shes' (ITupSingle she'))
                                    (ITupPair idxes' (ITupSingle idxe')) =
                C.EOp (C.EOp (build shr' shes' idxes') "*" she') "+" idxe'
            build _ _ _ = error "compileExp': ToIndex shape/index expressions do not match their ShapeR"
        ((tree1, usedA1), res1) <- compileExp' aenv env she
        (sts1, shes) <- toStExprs (shapeType shr) res1
        ((tree2, usedA2), res2) <- compileExp' aenv env idxe
        (sts2, idxes) <- toStExprs (shapeType shr) res2
        return ((Node "ToIndex" [tree1, tree2], usedA1 ++ usedA2), Right (sts1 ++ sts2, ITupSingle (build shr shes idxes)))

    -- Multidimensional indexing is desugared to linear indexing.
    A.Index avar@(Var (ArrayR shr _) _) she ->
        compileExp' aenv env $
            A.LinearIndex avar (A.ToIndex shr (A.Shape avar) she)

    A.LinearIndex (Var _ idx) e -> do
        -- Compute the linear index once into a 64-bit temporary, then read
        -- every data array of the array at that position.
        temp <- genName "i"
        let sts0 = [C.SDecl (C.TInt C.B64) temp Nothing]
        ((tree1, usedA1), sts1) <-
            fmap (`toStoring` ITupSingle (TypedName (C.TInt C.B64) temp))
              <$> compileExp' aenv env e
        let (shnames, anames) = aveprj aenv idx
            -- Recorded on every read; the same array may appear repeatedly.
            usedA = SomeArray shnames anames : usedA1
        return ((Node ("LinearIndex a" ++ show (idxToInt idx)) [tree1], usedA)
               ,Right (sts0 ++ sts1
                      ,itupmap (\(TypedAName _ name) -> C.EIndex name (C.EVar temp)) anames))

    e -> throw $ "Unsupported Exp constructor: " ++ A.showExpOp e
  where
    -- Compile a pair-typed argument and combine its two scalar components
    -- with the given C infix operator.
    binary :: AVarEnv aenv -> VarEnv env -> String -> A.OpenExp env aenv (a, b)
           -> SC ((Tree, [SomeArray]), Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
    binary aenv' env' op e' = do
        ((tree, usedA), res) <- compileExp' aenv' env' e'
        (sts, ITupPair (ITupSingle e1) (ITupSingle e2)) <-
            toStExprs (A.expType e') res
        return ((Node ("binary " ++ show op) [tree], usedA), Right (sts, ITupSingle (C.EOp e1 op e2)))

    -- Compile a scalar argument and wrap its expression with the given
    -- C-expression transformer; 'name' is only used for the debug tree.
    unary :: AVarEnv aenv -> VarEnv env -> String -> (C.Expr -> C.Expr) -> A.OpenExp env aenv a
          -> SC ((Tree, [SomeArray]), Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
    unary aenv' env' name op e' = do
        ((tree, usedA), res) <- compileExp' aenv' env' e'
        (sts, ITupSingle e1) <- toStExprs (A.expType e') res
        return ((Node ("unary " ++ name) [tree], usedA), Right (sts, ITupSingle (op e1)))

-- | Bring a compiled result into statements-plus-expressions form.
--
-- A @Right@ result is returned unchanged. For a @Left@ (store-into-names)
-- result, fresh variables of the given type are declared and the value is
-- stored into them. The resulting expressions are reads of those variables.
toStExprs :: TypeR t -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t)
toStExprs _ (Right ready) = pure ready
toStExprs ty (Left storeInto) = do
    fresh <- genVars ty
    let decls = [C.SDecl ctype var Nothing | TypedName ctype var <- itupList fresh]
        reads = itupmap (\(TypedName _ var) -> C.EVar var) fresh
    pure (decls ++ storeInto fresh, reads)

-- | Turn a compiled result into statements that store its value into the
-- given variables.
--
-- A @Left@ result already is such a storer. For a @Right (stmts, exprs)@
-- result, @stmts@ are emitted first, followed by one assignment per scalar
-- component. Components whose name is ignored get no assignment.
-- Calls 'error' if the names and expressions have different tuple structure.
toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt]
toStoring (Left f) = f
toStoring (Right (sts, exs)) = (sts ++) . flip go exs
  where
    go :: Names t -> Exprs t -> [C.Stmt]
    go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex]
    go (ITupSingle _) _ =
        error "toStoring: a single variable was paired with a non-scalar expression tuple"
    go ITupIgnore _ = []
    go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2
    go (ITupPair _ _) _ =
        error "toStoring: a pair of variables was paired with a non-pair expression tuple"

-- | Render a scalar constant as a C literal expression.
--
-- Returns 'Nothing' for values this backend cannot spell as a C literal:
-- vector types, half-precision floats, and NaN or infinite floating values.
-- Integer literals carry a suffix or cast matching the C type used for them.
showExpConst :: ScalarType t -> t -> Maybe String
showExpConst = \case
    SingleScalarType (NumSingleType (IntegralNumType it)) -> Just . goI it
    SingleScalarType (NumSingleType (FloatingNumType ft)) -> goF ft
    VectorScalarType _ -> const Nothing
  where
    goI :: IntegralType t -> t -> String
    goI TypeInt = signedLL
    goI TypeInt8 = ("(int8_t)" ++) . show
    goI TypeInt16 = ("(int16_t)" ++) . show
    goI TypeInt32 = show
    goI TypeInt64 = signedLL
    goI TypeWord = (++ "ULL") . show
    goI TypeWord8 = ("(uint8_t)" ++) . show
    goI TypeWord16 = ("(uint16_t)" ++) . show
    goI TypeWord32 = (++ "U") . show
    goI TypeWord64 = (++ "ULL") . show

    -- C parses "-9223372036854775808LL" as the negation of
    -- 9223372036854775808LL, which does not fit in long long. Spell the
    -- minimum value as "(min + 1) - 1" so the literal stays in range.
    signedLL :: (Bounded a, Eq a, Num a, Show a) => a -> String
    signedLL x
      | x == minBound = "(" ++ show (x + 1) ++ "LL - 1)"
      | otherwise = show x ++ "LL"

    goF :: FloatingType t -> t -> Maybe String
    goF TypeHalf = const Nothing
    goF TypeFloat = finiteLit (++ "f")
    goF TypeDouble = finiteLit id

    -- 'show' renders NaN and infinities as "NaN" / "Infinity", which are not
    -- valid C literals. Refuse them rather than emit invalid C.
    finiteLit :: (RealFloat a, Show a) => (String -> String) -> a -> Maybe String
    finiteLit suffix x
      | isNaN x || isInfinite x = Nothing
      | otherwise = Just (suffix (show x))

-- | Allocate a fresh, typed C variable for every scalar bound by a
-- left-hand side, pushing each name onto the variable environment in binding
-- order (left component before right). Wildcards bind nothing.
genVarsEnv :: A.ELeftHandSide t env env' -> VarEnv env -> SC (Names t, VarEnv env')
genVarsEnv binder venv = case binder of
    LeftHandSideWildcard _ ->
        pure (ITupIgnore, venv)
    LeftHandSideSingle sty -> do
        var <- genName "x"
        ctype <- cvtType sty
        pure (ITupSingle (TypedName ctype var), VEPush var venv)
    LeftHandSidePair leftB rightB -> do
        (leftNames, venvL) <- genVarsEnv leftB venv
        (rightNames, venvLR) <- genVarsEnv rightB venvL
        pure (ITupPair leftNames rightNames, venvLR)

-- | Allocate fresh typed C variables for every scalar component of a tuple
-- type, left to right. Unit components get no variable.
genVars :: TypeR t -> SC (Names t)
genVars = \case
    TupRunit -> pure ITupIgnore
    TupRsingle sty -> genVar sty
    TupRpair leftT rightT -> do
        leftNames <- genVars leftT
        rightNames <- genVars rightT
        pure (ITupPair leftNames rightNames)

-- | Allocate one fresh C variable whose C type corresponds to the given
-- scalar type.
genVar :: ScalarType t -> SC (Names t)
genVar sty = do
    ctype <- cvtType sty
    var <- genName "x"
    pure (ITupSingle (TypedName ctype var))