diff options
Diffstat (limited to 'SC/Exp.hs')
-rw-r--r-- | SC/Exp.hs | 170 |
1 files changed, 170 insertions, 0 deletions
diff --git a/SC/Exp.hs b/SC/Exp.hs new file mode 100644 index 0000000..d033cc8 --- /dev/null +++ b/SC/Exp.hs @@ -0,0 +1,170 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +module SC.Exp where + +import qualified Data.Array.Accelerate.AST as A +import Data.Array.Accelerate.AST.LeftHandSide +import Data.Array.Accelerate.AST.Var +import Data.Array.Accelerate.Representation.Array +import Data.Array.Accelerate.Representation.Shape +import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Type + +import qualified Language.C as C +import SC.Defs +import SC.Monad + + +data CompiledFun aenv t1 t2 = + CompiledFun + C.FunDef -- ^ expression function implementation + (Exprs t1 -> Names t2 -> [C.Expr]) + -- ^ arguments builder. Given: + -- - expressions that compute the direct arguments; + -- - names that the output values should be stored in; + -- returns the list of arguments to be passed to the compiled + -- function. The outputs will be stored by storing to pointers to + -- the given names. + -- The arguments will refer to array variable names found in the + -- original array environment. + +-- | The function must be single-argument. Uncurry if necessary (e.g. for zipWith). +compileFun :: AVarEnv aenv -> A.Fun aenv (t1 -> t2) -> SC (CompiledFun aenv t1 t2) +compileFun aenv (A.Lam lhs (A.Body body)) = do + funname <- genName "expfun_" + (argnames, env) <- genVarsEnv lhs VENil + outnames <- itupmap (\(TypedName t n) -> TypedName (C.TPtr t) n) + <$> genVars (A.expType body) + (usedA, res) <- compileExp' aenv env body + (sts1, retexprs) <- toStExprs (A.expType body) res + let sts2 = genoutstores outnames retexprs + arguments = + map (\(TypedAName t n) -> (t, n)) usedA + ++ map (\(TypedName t n) -> (t, n)) (itupList argnames) + ++ map (\(TypedName t n) -> (t, n)) (itupList outnames) + return $ CompiledFun + (C.ProcDef funname arguments (sts1 ++ sts2)) + (\argexprs destnames -> + map (\(TypedAName _ n) -> C.EVar n) usedA + ++ itupList argexprs + ++ map (\(TypedName _ n) -> C.EPtrTo (C.EVar n)) (itupList destnames)) + where + genoutstores :: Names t -> Exprs t -> [C.Stmt] + genoutstores ITupIgnore _ = [] + genoutstores (ITupSingle (TypedName _ n)) (ITupSingle e) = [C.SStore n (C.ELit "0") e] + genoutstores (ITupPair n1 n2) (ITupPair e1 e2) = genoutstores n1 e1 ++ genoutstores n2 e2 + genoutstores _ _ = error "wat" +compileFun _ _ = error "compileFun: Not single-argument function" + +compileExp :: AVarEnv aenv -> A.Exp aenv t -> SC (CompiledFun aenv () t) +compileExp aenv expr = compileFun aenv (A.Lam (LeftHandSideWildcard TupRunit) (A.Body expr)) + +compileExp' :: AVarEnv aenv -> VarEnv env -> A.OpenExp env aenv t + -> SC ([TypedAName], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t)) +compileExp' aenv env = \case + A.Let lhs rhs body -> do + (names, env') <- genVarsEnv lhs env + let sts1 = [C.SDecl t n Nothing | TypedName t n <- itupList names] + (usedA2, sts2) <- fmap (`toStoring` names) <$> compileExp' aenv env rhs + (usedA3, res3) <- compileExp' aenv env' body + return (usedA2 ++ usedA3 + ,fmap (\(sts, exs) -> (sts1 ++ sts2 ++ sts, exs)) res3) + + A.Evar (Var _ idx) -> + return ([], Right ([], ITupSingle (C.EVar (veprj env idx)))) + + A.Pair a b -> do + (usedA1, res1) <- compileExp' aenv env a + (usedA2, res2) <- compileExp' aenv env b + return (usedA1 ++ usedA2, Left (\case + ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2 + ITupIgnore -> [] + ITupSingle _ -> error "wat")) + + A.PrimApp (A.PrimAdd _) e -> binary aenv env "+" e + A.PrimApp (A.PrimSub _) e -> binary aenv env "-" e + A.PrimApp (A.PrimMul _) e -> binary aenv env "*" e + A.PrimApp (A.PrimQuot _) e -> binary aenv env "/" e + A.PrimApp (A.PrimRem _) e -> binary aenv env "%" e + + A.Shape (Var _ idx) -> + let (shnames, _) = aveprj aenv idx + buildExprs :: ShNames sh -> Exprs sh + buildExprs ShZ = ITupIgnore + buildExprs (ShS n names) = ITupPair (buildExprs names) (ITupSingle (C.EVar n)) + in return ([], Right ([], buildExprs shnames)) + + A.ToIndex shr she idxe -> do + let build :: ShapeR sh -> Exprs sh -> Exprs sh -> C.Expr + build ShapeRz _ _ = C.ELit "0" + build (ShapeRsnoc ShapeRz) _ (ITupPair _ (ITupSingle idxe')) = idxe' + build (ShapeRsnoc shr') (ITupPair shes' (ITupSingle she')) + (ITupPair idxes' (ITupSingle idxe')) = + C.EOp (C.EOp (build shr' shes' idxes') "*" she') "+" idxe' + build _ _ _ = error "wat" + (usedA1, res1) <- compileExp' aenv env she + (sts1, shes) <- toStExprs (shapeType shr) res1 + (usedA2, res2) <- compileExp' aenv env idxe + (sts2, idxes) <- toStExprs (shapeType shr) res2 + return (usedA1 ++ usedA2, Right (sts1 ++ sts2, ITupSingle (build shr shes idxes))) + + A.Index avar@(Var (ArrayR shr _) _) she -> + compileExp' aenv env $ + A.LinearIndex avar (A.ToIndex shr (A.Shape avar) she) + + A.LinearIndex (Var _ idx) e -> do + temp <- genName "i" + let sts0 = [C.SDecl (C.TInt C.B64) temp Nothing] + (usedA1, sts1) <- fmap (`toStoring` ITupSingle (TypedName (C.TInt C.B64) temp)) + <$> compileExp' aenv env e + let (_, anames) = aveprj aenv idx + usedA = itupList anames ++ usedA1 + return (usedA, Right (sts0 ++ sts1 + ,itupmap (\(TypedAName _ name) -> C.EIndex name (C.EVar temp)) anames)) + + _ -> throw "Unsupported Exp constructor" + where + binary :: AVarEnv aenv -> VarEnv env -> String -> A.OpenExp env aenv (a, b) + -> SC ([TypedAName], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t)) + binary aenv' env' op e' = do + (usedA, res) <- compileExp' aenv' env' e' + (sts, ITupPair (ITupSingle e1) (ITupSingle e2)) <- + toStExprs (A.expType e') res + return (usedA, Right (sts, ITupSingle (C.EOp e1 op e2))) + +toStExprs :: TypeR t -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t) +toStExprs ty (Left fun) = do + names <- genVars ty + let sts1 = fun names + return (sts1, itupmap (\(TypedName _ n) -> C.EVar n) names) +toStExprs _ (Right pair) = return pair + +toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt] +toStoring (Left f) = f +toStoring (Right (sts, exs)) = (sts ++) . flip go exs + where + go :: Names t -> Exprs t -> [C.Stmt] + go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex] + go (ITupSingle _) _ = error "wat" + go ITupIgnore _ = [] + go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2 + go (ITupPair _ _) _ = error "wat" + +genVarsEnv :: A.ELeftHandSide t env env' -> VarEnv env -> SC (Names t, VarEnv env') +genVarsEnv (LeftHandSideWildcard _) env = return (ITupIgnore, env) +genVarsEnv (LeftHandSideSingle ty) env = do + name <- genName "x" + ty' <- cvtType ty + return (ITupSingle (TypedName ty' name), VEPush name env) +genVarsEnv (LeftHandSidePair lhs1 lhs2) env = do + (n1, env1) <- genVarsEnv lhs1 env + (n2, env2) <- genVarsEnv lhs2 env1 + return (ITupPair n1 n2, env2) + +genVars :: TypeR t -> SC (Names t) +genVars TupRunit = return ITupIgnore +genVars (TupRsingle ty) = genVar ty +genVars (TupRpair t1 t2) = ITupPair <$> genVars t1 <*> genVars t2 + +genVar :: ScalarType t -> SC (Names t) +genVar ty = ITupSingle <$> (TypedName <$> cvtType ty <*> genName "x") |