summaryrefslogtreecommitdiff
path: root/SC/Exp.hs
diff options
context:
space:
mode:
Diffstat (limited to 'SC/Exp.hs')
-rw-r--r--SC/Exp.hs170
1 files changed, 170 insertions, 0 deletions
diff --git a/SC/Exp.hs b/SC/Exp.hs
new file mode 100644
index 0000000..d033cc8
--- /dev/null
+++ b/SC/Exp.hs
@@ -0,0 +1,170 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+module SC.Exp where
+
+import qualified Data.Array.Accelerate.AST as A
+import Data.Array.Accelerate.AST.LeftHandSide
+import Data.Array.Accelerate.AST.Var
+import Data.Array.Accelerate.Representation.Array
+import Data.Array.Accelerate.Representation.Shape
+import Data.Array.Accelerate.Representation.Type
+import Data.Array.Accelerate.Type
+
+import qualified Language.C as C
+import SC.Defs
+import SC.Monad
+
+
+data CompiledFun aenv t1 t2 =
+ CompiledFun
+ C.FunDef -- ^ expression function implementation
+ (Exprs t1 -> Names t2 -> [C.Expr])
+ -- ^ arguments builder. Given:
+ -- - expressions that compute the direct arguments;
+ -- - names that the output values should be stored in;
+ -- returns the list of arguments to be passed to the compiled
+ -- function. The outputs will be stored by storing to pointers to
+ -- the given names.
+ -- The arguments will refer to array variable names found in the
+ -- original array environment.
+
+-- | The function must be single-argument. Uncurry if necessary (e.g. for zipWith).
+compileFun :: AVarEnv aenv -> A.Fun aenv (t1 -> t2) -> SC (CompiledFun aenv t1 t2)
+compileFun aenv (A.Lam lhs (A.Body body)) = do
+ funname <- genName "expfun_"
+ (argnames, env) <- genVarsEnv lhs VENil
+ outnames <- itupmap (\(TypedName t n) -> TypedName (C.TPtr t) n)
+ <$> genVars (A.expType body)
+ (usedA, res) <- compileExp' aenv env body
+ (sts1, retexprs) <- toStExprs (A.expType body) res
+ let sts2 = genoutstores outnames retexprs
+ arguments =
+ map (\(TypedAName t n) -> (t, n)) usedA
+ ++ map (\(TypedName t n) -> (t, n)) (itupList argnames)
+ ++ map (\(TypedName t n) -> (t, n)) (itupList outnames)
+ return $ CompiledFun
+ (C.ProcDef funname arguments (sts1 ++ sts2))
+ (\argexprs destnames ->
+ map (\(TypedAName _ n) -> C.EVar n) usedA
+ ++ itupList argexprs
+ ++ map (\(TypedName _ n) -> C.EPtrTo (C.EVar n)) (itupList destnames))
+ where
+ genoutstores :: Names t -> Exprs t -> [C.Stmt]
+ genoutstores ITupIgnore _ = []
+ genoutstores (ITupSingle (TypedName _ n)) (ITupSingle e) = [C.SStore n (C.ELit "0") e]
+ genoutstores (ITupPair n1 n2) (ITupPair e1 e2) = genoutstores n1 e1 ++ genoutstores n2 e2
+ genoutstores _ _ = error "wat"
+compileFun _ _ = error "compileFun: Not single-argument function"
+
+compileExp :: AVarEnv aenv -> A.Exp aenv t -> SC (CompiledFun aenv () t)
+compileExp aenv expr = compileFun aenv (A.Lam (LeftHandSideWildcard TupRunit) (A.Body expr))
+
+compileExp' :: AVarEnv aenv -> VarEnv env -> A.OpenExp env aenv t
+ -> SC ([TypedAName], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
+compileExp' aenv env = \case
+ A.Let lhs rhs body -> do
+ (names, env') <- genVarsEnv lhs env
+ let sts1 = [C.SDecl t n Nothing | TypedName t n <- itupList names]
+ (usedA2, sts2) <- fmap (`toStoring` names) <$> compileExp' aenv env rhs
+ (usedA3, res3) <- compileExp' aenv env' body
+ return (usedA2 ++ usedA3
+ ,fmap (\(sts, exs) -> (sts1 ++ sts2 ++ sts, exs)) res3)
+
+ A.Evar (Var _ idx) ->
+ return ([], Right ([], ITupSingle (C.EVar (veprj env idx))))
+
+ A.Pair a b -> do
+ (usedA1, res1) <- compileExp' aenv env a
+ (usedA2, res2) <- compileExp' aenv env b
+ return (usedA1 ++ usedA2, Left (\case
+ ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2
+ ITupIgnore -> []
+ ITupSingle _ -> error "wat"))
+
+ A.PrimApp (A.PrimAdd _) e -> binary aenv env "+" e
+ A.PrimApp (A.PrimSub _) e -> binary aenv env "-" e
+ A.PrimApp (A.PrimMul _) e -> binary aenv env "*" e
+ A.PrimApp (A.PrimQuot _) e -> binary aenv env "/" e
+ A.PrimApp (A.PrimRem _) e -> binary aenv env "%" e
+
+ A.Shape (Var _ idx) ->
+ let (shnames, _) = aveprj aenv idx
+ buildExprs :: ShNames sh -> Exprs sh
+ buildExprs ShZ = ITupIgnore
+ buildExprs (ShS n names) = ITupPair (buildExprs names) (ITupSingle (C.EVar n))
+ in return ([], Right ([], buildExprs shnames))
+
+ A.ToIndex shr she idxe -> do
+ let build :: ShapeR sh -> Exprs sh -> Exprs sh -> C.Expr
+ build ShapeRz _ _ = C.ELit "0"
+ build (ShapeRsnoc ShapeRz) _ (ITupPair _ (ITupSingle idxe')) = idxe'
+ build (ShapeRsnoc shr') (ITupPair shes' (ITupSingle she'))
+ (ITupPair idxes' (ITupSingle idxe')) =
+ C.EOp (C.EOp (build shr' shes' idxes') "*" she') "+" idxe'
+ build _ _ _ = error "wat"
+ (usedA1, res1) <- compileExp' aenv env she
+ (sts1, shes) <- toStExprs (shapeType shr) res1
+ (usedA2, res2) <- compileExp' aenv env idxe
+ (sts2, idxes) <- toStExprs (shapeType shr) res2
+ return (usedA1 ++ usedA2, Right (sts1 ++ sts2, ITupSingle (build shr shes idxes)))
+
+ A.Index avar@(Var (ArrayR shr _) _) she ->
+ compileExp' aenv env $
+ A.LinearIndex avar (A.ToIndex shr (A.Shape avar) she)
+
+ A.LinearIndex (Var _ idx) e -> do
+ temp <- genName "i"
+ let sts0 = [C.SDecl (C.TInt C.B64) temp Nothing]
+ (usedA1, sts1) <- fmap (`toStoring` ITupSingle (TypedName (C.TInt C.B64) temp))
+ <$> compileExp' aenv env e
+ let (_, anames) = aveprj aenv idx
+ usedA = itupList anames ++ usedA1
+ return (usedA, Right (sts0 ++ sts1
+ ,itupmap (\(TypedAName _ name) -> C.EIndex name (C.EVar temp)) anames))
+
+ _ -> throw "Unsupported Exp constructor"
+ where
+ binary :: AVarEnv aenv -> VarEnv env -> String -> A.OpenExp env aenv (a, b)
+ -> SC ([TypedAName], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
+ binary aenv' env' op e' = do
+ (usedA, res) <- compileExp' aenv' env' e'
+ (sts, ITupPair (ITupSingle e1) (ITupSingle e2)) <-
+ toStExprs (A.expType e') res
+ return (usedA, Right (sts, ITupSingle (C.EOp e1 op e2)))
+
+toStExprs :: TypeR t -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t)
+toStExprs ty (Left fun) = do
+ names <- genVars ty
+ let sts1 = fun names
+ return (sts1, itupmap (\(TypedName _ n) -> C.EVar n) names)
+toStExprs _ (Right pair) = return pair
+
+toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt]
+toStoring (Left f) = f
+toStoring (Right (sts, exs)) = (sts ++) . flip go exs
+ where
+ go :: Names t -> Exprs t -> [C.Stmt]
+ go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex]
+ go (ITupSingle _) _ = error "wat"
+ go ITupIgnore _ = []
+ go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2
+ go (ITupPair _ _) _ = error "wat"
+
+genVarsEnv :: A.ELeftHandSide t env env' -> VarEnv env -> SC (Names t, VarEnv env')
+genVarsEnv (LeftHandSideWildcard _) env = return (ITupIgnore, env)
+genVarsEnv (LeftHandSideSingle ty) env = do
+ name <- genName "x"
+ ty' <- cvtType ty
+ return (ITupSingle (TypedName ty' name), VEPush name env)
+genVarsEnv (LeftHandSidePair lhs1 lhs2) env = do
+ (n1, env1) <- genVarsEnv lhs1 env
+ (n2, env2) <- genVarsEnv lhs2 env1
+ return (ITupPair n1 n2, env2)
+
+genVars :: TypeR t -> SC (Names t)
+genVars TupRunit = return ITupIgnore
+genVars (TupRsingle ty) = genVar ty
+genVars (TupRpair t1 t2) = ITupPair <$> genVars t1 <*> genVars t2
+
+genVar :: ScalarType t -> SC (Names t)
+genVar ty = ITupSingle <$> (TypedName <$> cvtType ty <*> genName "x")