{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module SC.Exp where

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type

import qualified Language.C as C

import SC.Defs
import SC.Monad


-- | A compiled scalar function: its C definition together with everything a
-- call site needs to invoke it.
data CompiledFun aenv t1 t2 where
    -- | The constructor fields, in order:
    --
    -- 1. The C definition implementing the expression function.
    --
    -- 2. The arguments builder. Given the expressions that compute the direct
    --    arguments, and the names that the output values should be stored
    --    in, it returns the list of arguments to pass to the compiled
    --    function. Outputs are written by storing through pointers to the
    --    given names. The arguments refer to array variable names found in
    --    the original array environment.
    --
    -- 3. The arrays from the environment that the constructed arguments use.
    CompiledFun
        :: C.FunDef
        -> (Exprs t1 -> Names t2 -> [C.Expr])
        -> [SomeArray]
        -> CompiledFun aenv t1 t2

-- | The variable names corresponding to a single source-level array (before
-- SoA conversion): the names for its shape and the names for its data.
data SomeArray where
    SomeArray :: ShNames sh -> ANames t -> SomeArray

-- | The function must be single-argument. Uncurry if necessary (e.g. for zipWith).
--
-- The generated C procedure is marked static and named with a fresh
-- "expfun_" name. Its parameters are, in order: the shape and data names of
-- every array recorded as used by the body, the names bound by the lambda's
-- left-hand side, and one pointer parameter per output component. The body
-- runs the statements computing the result and then stores each result
-- component through its output pointer at index 0.
compileFun :: AVarEnv aenv -> A.Fun aenv (t1 -> t2) -> SC (CompiledFun aenv t1 t2)
compileFun aenv (A.Lam lhs (A.Body body)) = do
    funname <- genName "expfun_"
    (argnames, env) <- genVarsEnv lhs VENil
    outnames <- itupmap (\(TypedName t n) -> TypedName (C.TPtr t) n)
                    <$> genVars (A.expType body)
    (usedA, res) <- compileExp' aenv env body
    (sts1, retexprs) <- toStExprs (A.expType body) res
    let sts2 = genoutstores outnames retexprs
        -- NOTE(review): usedA is not deduplicated here. If the same array is
        -- indexed more than once, its names appear more than once and would
        -- repeat C parameter names. Confirm whether this is prevented
        -- elsewhere.
        arrayarguments =
            concatMap (\(SomeArray shn ans) ->
                          map (\(TypedName t n) -> (t, n)) (shnamesList shn)
                          ++ map (\(TypedAName t n) -> (t, n)) (itupList ans))
                      usedA
        arguments =
            arrayarguments
            ++ map (\(TypedName t n) -> (t, n)) (itupList argnames)
            ++ map (\(TypedName t n) -> (t, n)) (itupList outnames)
    return $ CompiledFun
        (C.ProcDef C.defAttrs { C.faStatic = True } funname arguments (sts1 ++ sts2))
        -- Call-site arguments mirror the parameter list: array variables,
        -- the caller's argument expressions, then addresses of destinations.
        (\argexprs destnames ->
            map (C.EVar . snd) arrayarguments
            ++ itupList argexprs
            ++ map (\(TypedName _ n) -> C.EPtrTo (C.EVar n)) (itupList destnames))
        usedA
  where
    -- Store each result expression through the matching output pointer.
    genoutstores :: Names t -> Exprs t -> [C.Stmt]
    genoutstores ITupIgnore _ = []
    genoutstores (ITupSingle (TypedName _ n)) (ITupSingle e) = [C.SStore n (C.ELit "0") e]
    genoutstores (ITupPair n1 n2) (ITupPair e1 e2) = genoutstores n1 e1 ++ genoutstores n2 e2
    genoutstores _ _ = error "wat"
compileFun _ _ = error "compileFun: Not single-argument function"

-- | Compile a closed expression as a function whose (unit) argument is
-- ignored.
compileExp :: AVarEnv aenv -> A.Exp aenv t -> SC (CompiledFun aenv () t)
compileExp aenv expr = compileFun aenv (A.Lam (LeftHandSideWildcard TupRunit) (A.Body expr))

-- | Compile an open expression. Returns the arrays it uses from the
-- environment together with its result, in one of two forms:
--
-- * @Right (sts, exprs)@: run @sts@ first, then @exprs@ denote the result
--   components;
--
-- * @Left store@: given destination names, @store@ produces statements that
--   write the result into them.
--
-- Throws in 'SC' for unsupported constructors (including constants that
-- 'showExpConst' cannot render).
compileExp' :: AVarEnv aenv -> VarEnv env -> A.OpenExp env aenv t
            -> SC ([SomeArray], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
compileExp' aenv env = \case
    A.Let lhs rhs body -> do
        (names, env') <- genVarsEnv lhs env
        let sts1 = [C.SDecl t n Nothing | TypedName t n <- itupList names]
        (usedA2, sts2) <- fmap (`toStoring` names) <$> compileExp' aenv env rhs
        (usedA3, res3) <- compileExp' aenv env' body
        -- The declarations and the right-hand side computation must precede
        -- the body regardless of which form the body's result takes.
        return (usedA2 ++ usedA3, prependStmts (sts1 ++ sts2) res3)

    A.Evar (Var _ idx) ->
        return ([], Right ([], ITupSingle (C.EVar (veprj env idx))))

    A.Nil ->
        return ([], Right ([], ITupIgnore))

    A.Pair a b -> do
        (usedA1, res1) <- compileExp' aenv env a
        (usedA2, res2) <- compileExp' aenv env b
        return ( usedA1 ++ usedA2
               , case (res1, res2) of
                   (Right (sts1, exp1), Right (sts2, exp2)) ->
                       Right (sts1 ++ sts2, ITupPair exp1 exp2)
                   _ -> Left (\case
                            ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2
                            ITupIgnore -> []
                            ITupSingle _ -> error "wat") )

    A.Const ty x
      | Just str <- showExpConst ty x ->
          return ([], Right ([], ITupSingle (C.ELit str)))

    A.PrimApp (A.PrimAdd _) e -> binary aenv env "+" e
    A.PrimApp (A.PrimSub _) e -> binary aenv env "-" e
    A.PrimApp (A.PrimMul _) e -> binary aenv env "*" e
    A.PrimApp (A.PrimQuot _) e -> binary aenv env "/" e
    A.PrimApp (A.PrimRem _) e -> binary aenv env "%" e

    -- NOTE(review): this references the array's shape names but records no
    -- SomeArray, so compileFun would not add those names as parameters when
    -- only the shape is read. Confirm this is intended.
    A.Shape (Var _ idx) ->
        let (shnames, _) = aveprj aenv idx
            buildExprs :: ShNames sh -> Exprs sh
            buildExprs ShZ = ITupIgnore
            buildExprs (ShS names n) = ITupPair (buildExprs names) (ITupSingle (C.EVar n))
        in return ([], Right ([], buildExprs shnames))

    A.ToIndex shr she idxe -> do
        -- Row-major linearisation: index = build(outer) * size + innermost.
        let build :: ShapeR sh -> Exprs sh -> Exprs sh -> C.Expr
            build ShapeRz _ _ = C.ELit "0"
            build (ShapeRsnoc ShapeRz) _ (ITupPair _ (ITupSingle idxe')) = idxe'
            build (ShapeRsnoc shr') (ITupPair shes' (ITupSingle she')) (ITupPair idxes' (ITupSingle idxe')) =
                C.EOp (C.EOp (build shr' shes' idxes') "*" she') "+" idxe'
            build _ _ _ = error "wat"
        (usedA1, res1) <- compileExp' aenv env she
        (sts1, shes) <- toStExprs (shapeType shr) res1
        (usedA2, res2) <- compileExp' aenv env idxe
        (sts2, idxes) <- toStExprs (shapeType shr) res2
        return (usedA1 ++ usedA2, Right (sts1 ++ sts2, ITupSingle (build shr shes idxes)))

    A.Index avar@(Var (ArrayR shr _) _) she ->
        compileExp' aenv env $ A.LinearIndex avar (A.ToIndex shr (A.Shape avar) she)

    A.LinearIndex (Var _ idx) e -> do
        -- The linear index is computed once into a 64-bit temporary and then
        -- used to index every data array of the (SoA) array.
        temp <- genName "i"
        let sts0 = [C.SDecl (C.TInt C.B64) temp Nothing]
        (usedA1, sts1) <- fmap (`toStoring` ITupSingle (TypedName (C.TInt C.B64) temp))
                              <$> compileExp' aenv env e
        let (shnames, anames) = aveprj aenv idx
            usedA = SomeArray shnames anames : usedA1
        return ( usedA
               , Right ( sts0 ++ sts1
                       , itupmap (\(TypedAName _ name) -> C.EIndex name (C.EVar temp)) anames ) )

    e -> throw $ "Unsupported Exp constructor: " ++ A.showExpOp e
  where
    -- Compile a binary primitive whose argument is a pair of scalars. The
    -- pattern on the bind is refutable; on mismatch the do-block calls
    -- 'fail' in SC.
    binary :: AVarEnv aenv -> VarEnv env -> String -> A.OpenExp env aenv (a, b)
           -> SC ([SomeArray], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
    binary aenv' env' op e' = do
        (usedA, res) <- compileExp' aenv' env' e'
        (sts, ITupPair (ITupSingle e1) (ITupSingle e2)) <- toStExprs (A.expType e') res
        return (usedA, Right (sts, ITupSingle (C.EOp e1 op e2)))

    -- Put statements in front of a result, in either of its two forms.
    prependStmts :: [C.Stmt]
                 -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t)
                 -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t)
    prependStmts pre (Left store) = Left (\dest -> pre ++ store dest)
    prependStmts pre (Right (sts, exs)) = Right (pre ++ sts, exs)

-- | Convert a result into statements plus expressions. A 'Left' result is
-- materialised into freshly declared variables of the given type.
toStExprs :: TypeR t -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t)
toStExprs ty (Left fun) = do
    names <- genVars ty
    let sts1 = [C.SDecl t n Nothing | TypedName t n <- itupList names]
        sts2 = fun names
    return (sts1 ++ sts2, itupmap (\(TypedName _ n) -> C.EVar n) names)
toStExprs _ (Right pair) = return pair

-- | Convert a result into a function producing statements that assign it to
-- the given names. Ignored name components receive no assignment.
toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt]
toStoring (Left f) = f
toStoring (Right (sts, exs)) = (sts ++) . flip go exs
  where
    go :: Names t -> Exprs t -> [C.Stmt]
    go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex]
    go (ITupSingle _) _ = error "wat"
    go ITupIgnore _ = []
    go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2
    go (ITupPair _ _) _ = error "wat"

-- | Render a scalar constant as a C literal. Returns 'Nothing' for types or
-- values that have no literal rendering here: vectors, half floats, and
-- NaN/infinite floats. The output of 'show' for those would be "NaN" or
-- "Infinity", which is not valid C.
showExpConst :: ScalarType t -> t -> Maybe String
showExpConst = \case
    SingleScalarType (NumSingleType (IntegralNumType it)) -> Just . goI it
    SingleScalarType (NumSingleType (FloatingNumType ft)) -> goF ft
    VectorScalarType _ -> const Nothing
  where
    goI :: IntegralType t -> t -> String
    goI TypeInt = signedLit "LL"
    goI TypeInt8 = ("(int8_t)" ++) . show
    goI TypeInt16 = ("(int16_t)" ++) . show
    goI TypeInt32 = signedLit ""
    goI TypeInt64 = signedLit "LL"
    goI TypeWord = (++ "ULL") . show
    goI TypeWord8 = ("(uint8_t)" ++) . show
    goI TypeWord16 = ("(uint16_t)" ++) . show
    goI TypeWord32 = (++ "U") . show
    goI TypeWord64 = (++ "ULL") . show

    goF :: FloatingType t -> t -> Maybe String
    goF TypeHalf = const Nothing
    goF TypeFloat = finiteLit ((++ "f") . show)
    goF TypeDouble = finiteLit show

    -- The most negative value is written as (min+1) - 1. Writing it directly
    -- would negate a positive constant that does not fit in the target type.
    signedLit :: (Bounded a, Integral a, Show a) => String -> a -> String
    signedLit suffix v
        | v == minBound = "(" ++ show (v + 1) ++ suffix ++ " - 1)"
        | otherwise = show v ++ suffix

    finiteLit :: RealFloat a => (a -> String) -> a -> Maybe String
    finiteLit render v
        | isNaN v || isInfinite v = Nothing
        | otherwise = Just (render v)

-- | Allocate fresh C variable names for the variables bound by a left-hand
-- side, extending the variable environment. Wildcards bind nothing.
genVarsEnv :: A.ELeftHandSide t env env' -> VarEnv env -> SC (Names t, VarEnv env')
genVarsEnv (LeftHandSideWildcard _) env = return (ITupIgnore, env)
genVarsEnv (LeftHandSideSingle ty) env = do
    name <- genName "x"
    ty' <- cvtType ty
    return (ITupSingle (TypedName ty' name), VEPush name env)
genVarsEnv (LeftHandSidePair lhs1 lhs2) env = do
    (n1, env1) <- genVarsEnv lhs1 env
    (n2, env2) <- genVarsEnv lhs2 env1
    return (ITupPair n1 n2, env2)

-- | Allocate fresh typed C variable names for every scalar in a tuple type.
genVars :: TypeR t -> SC (Names t)
genVars TupRunit = return ITupIgnore
genVars (TupRsingle ty) = genVar ty
genVars (TupRpair t1 t2) = ITupPair <$> genVars t1 <*> genVars t2

-- | Allocate one fresh typed C variable name for a scalar type.
genVar :: ScalarType t -> SC (Names t)
genVar ty = ITupSingle <$> (TypedName <$> cvtType ty <*> genName "x")