{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
-- | Compilation of Accelerate scalar expressions and scalar functions into
-- C procedures.
module SC.Exp where

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Idx (idxToInt)
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import Debug.Trace
import Debug

import qualified Language.C as C

import SC.Defs
import SC.Monad


-- | A scalar function of type @t1 -> t2@ compiled to a separate C procedure,
-- together with everything a call site needs to invoke it.
data CompiledFun aenv t1 t2 =
    CompiledFun
        C.FunDef
          -- ^ expression function implementation
        (Exprs t1 -> Names t2 -> [C.Expr])
          -- ^ arguments builder. Given:
          -- - expressions that compute the direct arguments;
          -- - names that the output values should be stored in;
          -- returns the list of arguments to be passed to the compiled
          -- function. The outputs will be stored by storing to pointers to
          -- the given names.
          -- The arguments will refer to array variable names found in the
          -- original array environment.
        [SomeArray]
          -- ^ Arrays that the constructed arguments use from the environment

-- | The variable names corresponding to a single source-level array (before
-- SoA conversion): the names of its shape components and the names of its
-- component arrays. The shape and element types are existentially hidden so
-- that arrays of different types can be collected in a single list.
data SomeArray = forall sh t. SomeArray (ShNames sh) (ANames t)

-- | The function must be single-argument. Uncurry if necessary (e.g. for zipWith).
--
-- Compile a one-argument Accelerate scalar function into a @static@ C
-- procedure. The procedure's parameters are, in order:
--
--   1. the shape variables and data pointers of every environment array the
--      body reads;
--   2. the function's argument variables;
--   3. one pointer per output component; the results are stored through
--      these pointers.
--
-- The returned 'CompiledFun' bundles the procedure definition, a builder for
-- the matching call-site argument list, and the list of arrays used.
--
-- NOTE(review): the used-array list is not de-duplicated, so a body that
-- indexes the same array more than once produces repeated C parameters with
-- the same names. Confirm how callers are expected to handle this.
compileFun :: AVarEnv aenv -> A.Fun aenv (t1 -> t2) -> SC (CompiledFun aenv t1 t2)
compileFun aenv (A.Lam lhs (A.Body body)) = do
    funname <- genName "expfun_"
    (argnames, env) <- genVarsEnv lhs VENil
    -- Results are returned through pointers, so the output parameters get a
    -- pointer type.
    outnames <- itupmap (\(TypedName t n) -> TypedName (C.TPtr t) n)
                    <$> genVars (A.expType body)
    ((_tree, usedA), res) <- compileExp' aenv env body
    traceM ("Compiled expression:\n" ++ prettyTree " " " " _tree)
    (sts1, retexprs) <- toStExprs (A.expType body) res
    let sts2 = genoutstores outnames retexprs
        arrayarguments =
            concatMap (\(SomeArray shn ans) ->
                           map (\(TypedName t n) -> (t, n)) (shnamesList shn)
                           ++ map (\(TypedAName t n) -> (t, n)) (itupList ans))
                      usedA
        arguments =
            arrayarguments
            ++ map (\(TypedName t n) -> (t, n)) (itupList argnames)
            ++ map (\(TypedName t n) -> (t, n)) (itupList outnames)
    return $ CompiledFun
        (C.ProcDef C.defAttrs { C.faStatic = True } funname arguments (sts1 ++ sts2))
        -- Call-site arguments mirror 'arguments': array names, then the
        -- direct argument expressions, then addresses of the destinations.
        (\argexprs destnames ->
            map (C.EVar . snd) arrayarguments
            ++ itupList argexprs
            ++ map (\(TypedName _ n) -> C.EPtrTo (C.EVar n)) (itupList destnames))
        usedA
  where
    -- Store every result component through its output pointer.
    genoutstores :: Names t -> Exprs t -> [C.Stmt]
    genoutstores ITupIgnore _ = []
    genoutstores (ITupSingle (TypedName _ n)) (ITupSingle e) = [C.SStore n (C.ELit "0") e]
    genoutstores (ITupPair n1 n2) (ITupPair e1 e2) = genoutstores n1 e1 ++ genoutstores n2 e2
    genoutstores _ _ = error "compileFun: output names do not match the shape of the result expressions"
compileFun _ _ = error "compileFun: Not single-argument function"

-- | Compile an expression that takes no scalar arguments by wrapping it in a
-- function whose (unit) argument is ignored.
compileExp :: AVarEnv aenv -> A.Exp aenv t -> SC (CompiledFun aenv () t)
compileExp aenv expr = compileFun aenv (A.Lam (LeftHandSideWildcard TupRunit) (A.Body expr))

-- | A rose tree describing how an expression was compiled; rendered with
-- 'prettyTree' for debug tracing.
data Tree = Node String [Tree] | Leaf String

-- | Render a 'Tree' with box-drawing characters, one node per line.
--
-- The first argument is the prefix for the root's own line; the second is the
-- prefix used for the lines of its descendants. Every child prefix is two
-- columns wide ("├─"/"│ " for inner children, "└─"/"  " for the last one),
-- so subtrees line up regardless of their position.
prettyTree :: String -> String -> Tree -> String
prettyTree pre _ (Leaf s) = pre ++ s ++ "\n"
prettyTree pre pre2 (Node s []) = prettyTree pre pre2 (Leaf s)
prettyTree pre pre2 (Node s (t : ts)) = pre ++ s ++ "\n" ++ children t ts
  where
    -- Render the current child, knowing whether more siblings follow.
    children :: Tree -> [Tree] -> String
    children c [] = prettyTree (pre2 ++ "└─") (pre2 ++ "  ") c
    children c (c' : cs) = prettyTree (pre2 ++ "├─") (pre2 ++ "│ ") c ++ children c' cs

-- | Compile an open scalar expression.
--
-- Returns a debug 'Tree', the environment arrays whose variables the emitted
-- code references (see 'SomeArray'), and the result in one of two forms:
--
--   * @Right (stmts, exprs)@: execute @stmts@; afterwards @exprs@ are C
--     expressions for the result components;
--   * @Left store@: @store names@ yields statements that assign the result
--     into @names@, which the caller must already have declared.
--
-- Constructors that are not handled yet (they reach the "Unsupported Exp
-- constructor" error): Foreign, IndexSlice, IndexFull, FromIndex, Case, Cond,
-- While, PrimConst, ShapeSize, Undef, Coerce.
compileExp' :: AVarEnv aenv -> VarEnv env -> A.OpenExp env aenv t
            -> SC ((Tree, [SomeArray]), Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
compileExp' aenv env = \case
    A.Let lhs rhs body -> do
        (names, env') <- genVarsEnv lhs env
        let sts1 = [C.SDecl t n Nothing | TypedName t n <- itupList names]
        ((tree2, usedA2), sts2) <- fmap (`toStoring` names) <$> compileExp' aenv env rhs
        ((tree3, usedA3), res3) <- compileExp' aenv env' body
        -- The let-bound variables must be declared and assigned before the
        -- body runs, whichever form the body's result takes.
        let bindSts = sts1 ++ sts2
        return
            ( (Node ("Let [" ++ show (length (itupList names)) ++ " vars]") [tree2, tree3]
              , usedA2 ++ usedA3 )
            , case res3 of
                  Left store -> Left (\dest -> bindSts ++ store dest)
                  Right (sts, exs) -> Right (bindSts ++ sts, exs) )

    A.Evar (Var _ idx) ->
        return ((Leaf ("Evar " ++ show (idxToInt idx)), []), Right ([], ITupSingle (C.EVar (veprj env idx))))

    A.Nil -> return ((Leaf "Nil", []), Right ([], ITupIgnore))

    A.Pair a b -> do
        ((tree1, usedA1), res1) <- compileExp' aenv env a
        ((tree2, usedA2), res2) <- compileExp' aenv env b
        return
            ( (Node "Pair" [tree1, tree2], usedA1 ++ usedA2)
            , case (res1, res2) of
                  (Right (sts1, exp1), Right (sts2, exp2)) -> Right (sts1 ++ sts2, ITupPair exp1 exp2)
                  -- If either side can only store, the pair can only store.
                  _ -> Left (\case
                          ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2
                          ITupIgnore -> []
                          ITupSingle _ -> error "compileExp': Pair result stored into a single name") )

    A.Const ty x
      | Just str <- showExpConst ty x ->
          return ((Leaf ("Const (" ++ str ++ ")"), []), Right ([], ITupSingle (C.ELit str)))

    A.PrimApp (A.PrimAdd _) e -> binary aenv env "+" e
    A.PrimApp (A.PrimSub _) e -> binary aenv env "-" e
    A.PrimApp (A.PrimMul _) e -> binary aenv env "*" e
    A.PrimApp (A.PrimQuot _) e -> binary aenv env "/" e
    A.PrimApp (A.PrimRem _) e -> binary aenv env "%" e
    A.PrimApp (A.PrimFDiv _) e -> binary aenv env "/" e
    A.PrimApp (A.PrimLog TypeFloat) e -> unary aenv env "log" (C.ECall (C.Name "logf") . pure) e
    A.PrimApp (A.PrimLog TypeDouble) e -> unary aenv env "log" (C.ECall (C.Name "log") . pure) e
    A.PrimApp (A.PrimToFloating _ TypeFloat) e -> unary aenv env "cast float" (C.ECast C.TFloat) e
    A.PrimApp (A.PrimToFloating _ TypeDouble) e -> unary aenv env "cast double" (C.ECast C.TDouble) e
    A.PrimApp op _ -> throw $ "Unsupported Exp primitive operator: " ++ showPrimFun op

    -- NOTE(review): the array is not added to the used-array list here, so a
    -- 'Shape' of an array that is never indexed references shape variables
    -- that 'compileFun' does not pass as parameters. Confirm intended.
    A.Shape (Var _ idx) ->
        let (shnames, _) = aveprj aenv idx
            buildExprs :: ShNames sh -> Exprs sh
            buildExprs ShZ = ITupIgnore
            buildExprs (ShS names n) = ITupPair (buildExprs names) (ITupSingle (C.EVar n))
        in return ((Leaf ("Shape a" ++ show (idxToInt idx)), []), Right ([], buildExprs shnames))

    A.ToIndex shr she idxe -> do
        -- Row-major linearisation: ((i0 * s1 + i1) * s2 + i2) ...
        let build :: ShapeR sh -> Exprs sh -> Exprs sh -> C.Expr
            build ShapeRz _ _ = C.ELit "0"
            build (ShapeRsnoc ShapeRz) _ (ITupPair _ (ITupSingle idxe')) = idxe'
            build (ShapeRsnoc shr') (ITupPair shes' (ITupSingle she')) (ITupPair idxes' (ITupSingle idxe')) =
                C.EOp (C.EOp (build shr' shes' idxes') "*" she') "+" idxe'
            build _ _ _ = error "compileExp': ToIndex: shape and index do not match the shape representation"
        ((tree1, usedA1), res1) <- compileExp' aenv env she
        (sts1, shes) <- toStExprs (shapeType shr) res1
        ((tree2, usedA2), res2) <- compileExp' aenv env idxe
        (sts2, idxes) <- toStExprs (shapeType shr) res2
        return ((Node "ToIndex" [tree1, tree2], usedA1 ++ usedA2), Right (sts1 ++ sts2, ITupSingle (build shr shes idxes)))

    -- Multidimensional indexing is rewritten into a linear index computed
    -- from the array's own shape.
    A.Index avar@(Var (ArrayR shr _) _) she ->
        compileExp' aenv env $ A.LinearIndex avar (A.ToIndex shr (A.Shape avar) she)

    A.LinearIndex (Var _ idx) e -> do
        -- Evaluate the linear index into a fresh temporary, then read every
        -- component array at that index.
        temp <- genName "i"
        let sts0 = [C.SDecl (C.TInt C.B64) temp Nothing]
        ((tree1, usedA1), sts1) <-
            fmap (`toStoring` ITupSingle (TypedName (C.TInt C.B64) temp))
                <$> compileExp' aenv env e
        let (shnames, anames) = aveprj aenv idx
            usedA = SomeArray shnames anames : usedA1
        return
            ( (Node ("LinearIndex a" ++ show (idxToInt idx)) [tree1], usedA)
            , Right ( sts0 ++ sts1
                    , itupmap (\(TypedAName _ name) -> C.EIndex name (C.EVar temp)) anames ) )

    -- Everything else, including While (not implemented yet).
    e -> throw $ "Unsupported Exp constructor: " ++ A.showExpOp e
  where
    -- A primitive binary operator on a pair argument, emitted as an infix C
    -- operator.
    binary :: AVarEnv aenv -> VarEnv env -> String -> A.OpenExp env aenv (a, b)
           -> SC ((Tree, [SomeArray]), Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
    binary aenv' env' op e' = do
        ((tree, usedA), res) <- compileExp' aenv' env' e'
        (sts, ITupPair (ITupSingle e1) (ITupSingle e2)) <- toStExprs (A.expType e') res
        return ((Node ("binary " ++ show op) [tree], usedA), Right (sts, ITupSingle (C.EOp e1 op e2)))

    -- A primitive unary operator; @op@ wraps the single argument expression.
    unary :: AVarEnv aenv -> VarEnv env -> String -> (C.Expr -> C.Expr) -> A.OpenExp env aenv a
          -> SC ((Tree, [SomeArray]), Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t))
    unary aenv' env' name op e' = do
        ((tree, usedA), res) <- compileExp' aenv' env' e'
        (sts, ITupSingle e1) <- toStExprs (A.expType e') res
        return ((Node ("unary " ++ name) [tree], usedA), Right (sts, ITupSingle (op e1)))

-- | Normalise a compilation result to the "statements + expressions" form.
-- A storing continuation is run against freshly declared variables of the
-- given type, and the result expressions then read those variables.
toStExprs :: TypeR t -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t)
toStExprs ty (Left fun) = do
    names <- genVars ty
    let sts1 = [C.SDecl t n Nothing | TypedName t n <- itupList names]
        sts2 = fun names
    return (sts1 ++ sts2, itupmap (\(TypedName _ n) -> C.EVar n) names)
toStExprs _ (Right pair) = return pair

-- | Turn a compilation result into statements that leave the value in the
-- given variables. The variables are only assigned; declaring them is the
-- caller's responsibility. Ignored components are not assigned.
toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt]
toStoring (Left f) = f
toStoring (Right (sts, exs)) = (sts ++) . flip go exs
  where
    go :: Names t -> Exprs t -> [C.Stmt]
    go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex]
    go (ITupSingle _) _ = error "toStoring: single name paired with a non-single expression"
    go ITupIgnore _ = []
    go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2
    go (ITupPair _ _) _ = error "toStoring: pair of names paired with a non-pair expression"

-- | Render a scalar constant as a C literal, or 'Nothing' if it cannot be
-- expressed as one (vectors, half floats, infinities and NaNs).
showExpConst :: ScalarType t -> t -> Maybe String
showExpConst = \case
    SingleScalarType (NumSingleType (IntegralNumType it)) -> Just . goI it
    SingleScalarType (NumSingleType (FloatingNumType ft)) -> goF ft
    VectorScalarType _ -> const Nothing
  where
    goI :: IntegralType t -> t -> String
    goI TypeInt = signedLit "LL"
    goI TypeInt8 = ("(int8_t)" ++) . show
    goI TypeInt16 = ("(int16_t)" ++) . show
    goI TypeInt32 = signedLit ""
    goI TypeInt64 = signedLit "LL"
    goI TypeWord = (++ "ULL") . show
    goI TypeWord8 = ("(uint8_t)" ++) . show
    goI TypeWord16 = ("(uint16_t)" ++) . show
    goI TypeWord32 = (++ "U") . show
    goI TypeWord64 = (++ "ULL") . show

    goF :: FloatingType t -> t -> Maybe String
    goF TypeHalf = const Nothing
    goF TypeFloat = finiteLit (++ "f")
    goF TypeDouble = finiteLit id

    -- C has no negative literals: "-9223372036854775808LL" negates the
    -- positive literal 9223372036854775808LL, which does not fit in its type.
    -- The minimum value is therefore spelled as (min + 1) - 1.
    signedLit :: (Show a, Eq a, Bounded a, Num a) => String -> a -> String
    signedLit suffix x
      | x == minBound = "(" ++ show (x + 1) ++ suffix ++ " - 1)"
      | otherwise = show x ++ suffix

    -- 'show' renders non-finite values as "Infinity"/"NaN", which are not C
    -- literals; report them as unrepresentable instead of emitting bad C.
    finiteLit :: RealFloat a => (String -> String) -> a -> Maybe String
    finiteLit suffix x
      | isNaN x || isInfinite x = Nothing
      | otherwise = Just (suffix (show x))

-- | Allocate fresh C variables for the values bound by a left-hand side and
-- push them onto the scalar variable environment. Wildcards bind nothing.
genVarsEnv :: A.ELeftHandSide t env env' -> VarEnv env -> SC (Names t, VarEnv env')
genVarsEnv (LeftHandSideWildcard _) env = return (ITupIgnore, env)
genVarsEnv (LeftHandSideSingle ty) env = do
    name <- genName "x"
    ty' <- cvtType ty
    return (ITupSingle (TypedName ty' name), VEPush name env)
genVarsEnv (LeftHandSidePair lhs1 lhs2) env = do
    (n1, env1) <- genVarsEnv lhs1 env
    (n2, env2) <- genVarsEnv lhs2 env1
    return (ITupPair n1 n2, env2)

-- | Fresh C variables for every scalar component of a type; unit components
-- get no variable.
genVars :: TypeR t -> SC (Names t)
genVars TupRunit = return ITupIgnore
genVars (TupRsingle ty) = genVar ty
genVars (TupRpair t1 t2) = ITupPair <$> genVars t1 <*> genVars t2

-- | A single fresh C variable of the C type corresponding to a scalar type.
genVar :: ScalarType t -> SC (Names t)
genVar ty = ITupSingle <$> (TypedName <$> cvtType ty <*> genName "x")