{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module SC.Exp where

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type

import qualified Language.C as C
import SC.Defs
import SC.Monad


-- | A scalar function that has been compiled to a C procedure, paired with a
-- recipe for calling that procedure.
--
-- The first field is the generated procedure definition.
--
-- The second field is the argument-list builder. It receives
--
--   * the expressions computing the function's direct arguments, and
--   * the names the results should end up in,
--
-- and produces the C argument expressions to pass to the procedure. Results
-- are delivered by storing through pointers to the given names. Array
-- arguments refer to the array variable names of the original array
-- environment.
data CompiledFun aenv t1 t2 where
    CompiledFun
        :: C.FunDef
        -> (Exprs t1 -> Names t2 -> [C.Expr])
        -> CompiledFun aenv t1 t2

-- | Compile a single-argument scalar function to a C procedure. Functions of
-- several arguments (e.g. for zipWith) must be uncurried first.
--
-- The generated procedure's parameters are, in order: the arrays the body
-- reads, the function's bound arguments (components bound by a wildcard get
-- no parameter), and one pointer per result component, through which the
-- results are written.
compileFun :: AVarEnv aenv -> A.Fun aenv (t1 -> t2) -> SC (CompiledFun aenv t1 t2)
compileFun aenv (A.Lam lhs (A.Body body)) = do
    funname <- genName "expfun_"
    (argnames, env) <- genVarsEnv lhs VENil
    outnames <- itupmap (\(TypedName t n) -> TypedName (C.TPtr t) n) <$> genVars (A.expType body)
    (usedA, res) <- compileExp' aenv env body
    (sts1, retexprs) <- toStExprs (A.expType body) res
    let sts2 = genoutstores outnames retexprs
        arguments =
            map (\(TypedAName t n) -> (t, n)) usedA
            ++ map (\(TypedName t n) -> (t, n)) (itupList argnames)
            ++ map (\(TypedName t n) -> (t, n)) (itupList outnames)
    return $ CompiledFun
        (C.ProcDef funname arguments (sts1 ++ sts2))
        (\argexprs destnames ->
            map (\(TypedAName _ n) -> C.EVar n) usedA
            -- Only pass argument components that have a parameter: a
            -- wildcard in 'lhs' yields no parameter for its component, so
            -- passing every component would mismatch the procedure's arity.
            ++ selectArgs argnames argexprs
            ++ map (\(TypedName _ n) -> C.EPtrTo (C.EVar n)) (itupList destnames))
  where
    -- Store each result expression through the pointer of the matching
    -- output name (@n[0] = e@).
    genoutstores :: Names t -> Exprs t -> [C.Stmt]
    genoutstores ITupIgnore _ = []
    genoutstores (ITupSingle (TypedName _ n)) (ITupSingle e) = [C.SStore n (C.ELit "0") e]
    genoutstores (ITupPair n1 n2) (ITupPair e1 e2) = genoutstores n1 e1 ++ genoutstores n2 e2
    genoutstores _ _ = error "wat"

    -- Pick out the argument expressions at exactly the positions that the
    -- parameter names bind, in the same order as 'itupList' lists them.
    selectArgs :: Names t -> Exprs t -> [C.Expr]
    selectArgs ITupIgnore _ = []
    selectArgs (ITupSingle _) (ITupSingle e) = [e]
    selectArgs (ITupPair n1 n2) (ITupPair e1 e2) = selectArgs n1 e1 ++ selectArgs n2 e2
    selectArgs (ITupPair n1 n2) ITupIgnore = selectArgs n1 ITupIgnore ++ selectArgs n2 ITupIgnore
    selectArgs _ _ = error "compileFun: argument expressions do not match the parameter structure"
compileFun _ _ = error "compileFun: Not single-argument function"

-- | Compile a closed scalar expression by wrapping it in a function that
-- ignores a unit argument.
compileExp :: AVarEnv aenv -> A.Exp aenv t -> SC (CompiledFun aenv () t)
compileExp aenv expr = compileFun aenv (A.Lam (LeftHandSideWildcard TupRunit) (A.Body expr))

-- | Compile a scalar expression in the given array and scalar environments.
--
-- Returns the array names the expression reads, together with its code in
-- one of two forms:
--
--   * @Left store@: given destination names, @store@ produces the statements
--     that compute the value and assign it to those names;
--   * @Right (sts, exprs)@: after running @sts@, the value's components are
--     given by @exprs@.
--
-- 'toStExprs' and 'toStoring' convert either form to a single shape.
compileExp' :: AVarEnv aenv -> VarEnv env -> A.OpenExp env aenv t
            -> SC ([TypedAName], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
compileExp' aenv env = \case
    A.Let lhs rhs body -> do
        (names, env') <- genVarsEnv lhs env
        let sts1 = [C.SDecl t n Nothing | TypedName t n <- itupList names]
        (usedA2, sts2) <- fmap (`toStoring` names) <$> compileExp' aenv env rhs
        (usedA3, res3) <- compileExp' aenv env' body
        -- The binding's declarations and stores must precede the body's code
        -- in BOTH result forms; mapping only the 'Right' case would drop them
        -- whenever the body yields a storing function (e.g. an 'A.Pair').
        return (usedA2 ++ usedA3, prependStmts (sts1 ++ sts2) res3)

    A.Evar (Var _ idx) ->
        return ([], Right ([], ITupSingle (C.EVar (veprj env idx))))

    A.Pair a b -> do
        (usedA1, res1) <- compileExp' aenv env a
        (usedA2, res2) <- compileExp' aenv env b
        return (usedA1 ++ usedA2, Left (\case
            ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2
            ITupIgnore -> []
            ITupSingle _ -> error "wat"))

    A.PrimApp (A.PrimAdd _) e -> binary aenv env "+" e
    A.PrimApp (A.PrimSub _) e -> binary aenv env "-" e
    A.PrimApp (A.PrimMul _) e -> binary aenv env "*" e
    A.PrimApp (A.PrimQuot _) e -> binary aenv env "/" e
    A.PrimApp (A.PrimRem _) e -> binary aenv env "%" e

    A.Shape (Var _ idx) ->
        let (shnames, _) = aveprj aenv idx
            buildExprs :: ShNames sh -> Exprs sh
            buildExprs ShZ = ITupIgnore
            buildExprs (ShS n names) = ITupPair (buildExprs names) (ITupSingle (C.EVar n))
        in return ([], Right ([], buildExprs shnames))

    A.ToIndex shr she idxe -> do
        -- Row-major linearisation: index = (inner index) * size + last index.
        let build :: ShapeR sh -> Exprs sh -> Exprs sh -> C.Expr
            build ShapeRz _ _ = C.ELit "0"
            build (ShapeRsnoc ShapeRz) _ (ITupPair _ (ITupSingle idxe')) = idxe'
            build (ShapeRsnoc shr') (ITupPair shes' (ITupSingle she')) (ITupPair idxes' (ITupSingle idxe')) =
                C.EOp (C.EOp (build shr' shes' idxes') "*" she') "+" idxe'
            build _ _ _ = error "wat"
        (usedA1, res1) <- compileExp' aenv env she
        (sts1, shes) <- toStExprs (shapeType shr) res1
        (usedA2, res2) <- compileExp' aenv env idxe
        (sts2, idxes) <- toStExprs (shapeType shr) res2
        return (usedA1 ++ usedA2, Right (sts1 ++ sts2, ITupSingle (build shr shes idxes)))

    A.Index avar@(Var (ArrayR shr _) _) she ->
        compileExp' aenv env $ A.LinearIndex avar (A.ToIndex shr (A.Shape avar) she)

    A.LinearIndex (Var _ idx) e -> do
        temp <- genName "i"
        let sts0 = [C.SDecl (C.TInt C.B64) temp Nothing]
        (usedA1, sts1) <- fmap (`toStoring` ITupSingle (TypedName (C.TInt C.B64) temp)) <$> compileExp' aenv env e
        let (_, anames) = aveprj aenv idx
            usedA = itupList anames ++ usedA1
        return (usedA, Right (sts0 ++ sts1, itupmap (\(TypedAName _ name) -> C.EIndex name (C.EVar temp)) anames))

    _ -> throw "Unsupported Exp constructor"
  where
    -- Compile a binary primitive whose operand is a pair expression into a
    -- single infix C operation.
    binary :: AVarEnv aenv -> VarEnv env -> String -> A.OpenExp env aenv (a, b)
           -> SC ([TypedAName], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
    binary aenv' env' op e' = do
        (usedA, res) <- compileExp' aenv' env' e'
        (sts, ITupPair (ITupSingle e1) (ITupSingle e2)) <- toStExprs (A.expType e') res
        return (usedA, Right (sts, ITupSingle (C.EOp e1 op e2)))

    -- Put the given statements in front of the code of either result form.
    prependStmts :: [C.Stmt]
                 -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t)
                 -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t)
    prependStmts pre (Left store) = Left (\dests -> pre ++ store dests)
    prependStmts pre (Right (sts, exs)) = Right (pre ++ sts, exs)

-- | Bring a compiled expression into statements-plus-expressions form.
--
-- A storing function is run against fresh temporaries. The temporaries are
-- declared first, so the generated assignments target declared C variables.
toStExprs :: TypeR t -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t)
toStExprs ty (Left fun) = do
    names <- genVars ty
    let decls = [C.SDecl t n Nothing | TypedName t n <- itupList names]
        sts1 = fun names
    return (decls ++ sts1, itupmap (\(TypedName _ n) -> C.EVar n) names)
toStExprs _ (Right pair) = return pair

-- | Turn a compiled expression into statements that assign its value to the
-- given names. Components whose name is ignored receive no assignment.
toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt]
toStoring (Left f) = f
toStoring (Right (sts, exs)) = (sts ++) . flip go exs
  where
    go :: Names t -> Exprs t -> [C.Stmt]
    go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex]
    go (ITupSingle _) _ = error "wat"
    go ITupIgnore _ = []
    go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2
    go (ITupPair _ _) _ = error "wat"

-- | Allocate a fresh name for every variable bound by a left-hand side and
-- push those names onto the scalar environment. Wildcards bind nothing.
genVarsEnv :: A.ELeftHandSide t env env' -> VarEnv env -> SC (Names t, VarEnv env')
genVarsEnv (LeftHandSideWildcard _) env = return (ITupIgnore, env)
genVarsEnv (LeftHandSideSingle ty) env = do
    name <- genName "x"
    ty' <- cvtType ty
    return (ITupSingle (TypedName ty' name), VEPush name env)
genVarsEnv (LeftHandSidePair lhs1 lhs2) env = do
    (n1, env1) <- genVarsEnv lhs1 env
    (n2, env2) <- genVarsEnv lhs2 env1
    return (ITupPair n1 n2, env2)

-- | Allocate a fresh typed name for every scalar component of a type.
-- Unit components get no name.
genVars :: TypeR t -> SC (Names t)
genVars TupRunit = return ITupIgnore
genVars (TupRsingle ty) = genVar ty
genVars (TupRpair t1 t2) = ITupPair <$> genVars t1 <*> genVars t2

-- | Allocate a single fresh name of the C type corresponding to a scalar type.
genVar :: ScalarType t -> SC (Names t)
genVar ty = ITupSingle <$> (TypedName <$> cvtType ty <*> genName "x")