diff options
Diffstat (limited to 'src/Compile.hs')
| -rw-r--r-- | src/Compile.hs | 112 | 
1 files changed, 55 insertions, 57 deletions
| diff --git a/src/Compile.hs b/src/Compile.hs index 83c25c3..0db0d0f 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,12 +1,14 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE GADTs #-}  {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-}  {-# LANGUAGE TypeApplications #-}  module Compile where  import Control.Monad.Trans.State.Strict  import Data.Foldable (toList)  import Data.Functor.Const +import qualified Data.Functor.Product as Product  import Data.List (intersperse, intercalate)  import qualified Data.Map.Strict as Map  import Data.Map.Strict (Map) @@ -26,7 +28,7 @@ data StructDecl = StructDecl    deriving (Show)  data Stmt -  = SVarDecl String String CExpr  -- ^ type, variable name, right-hand side +  = SVarDecl Bool String String CExpr  -- ^ const, type, variable name, right-hand side    | SVarDeclUninit String String  -- ^ type, variable name (no initialiser)    | SAsg String CExpr  -- ^ variable name, right-hand side    | SBlock [Stmt] @@ -50,7 +52,7 @@ printStructDecl (StructDecl name contents comment) =  printStmt :: Int -> Stmt -> ShowS  printStmt indent = \case -  SVarDecl typ name rhs -> showString (typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";" +  SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";"    SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";")    SAsg name rhs -> showString (name ++ " = ") . printCExpr rhs . showString ";"    SBlock stmts -> @@ -92,15 +94,15 @@ genStructName = \t -> "ty_" ++ gen t where    -- all tags start with a letter, so the array mangling is unambiguous.    gen :: STy t -> String    gen STNil = "n" -  gen (STPair a b) = 'p' : gen a ++ gen b -  gen (STEither a b) = 'e' : gen a ++ gen b -  gen (STMaybe t) = 'm' : gen t +  gen (STPair a b) = 'P' : gen a ++ gen b +  gen (STEither a b) = 'E' : gen a ++ gen b +  gen (STMaybe t) = 'M' : gen t    gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t    gen (STScal st) = case st of -    STI32 -> "i4" -    STI64 -> "i8" -    STF32 -> "f4" -    STF64 -> "f8" +    STI32 -> "i" +    STI64 -> "j" +    STF32 -> "f" +    STF64 -> "d"      STBool -> "b"    gen (STAccum t) = 'C' : gen t @@ -110,20 +112,20 @@ genStruct topty = case topty of      Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "" com)    STPair a b ->      let name = genStructName (STPair a b) -    in Map.singleton name (StructDecl name (genStructName a ++ " a; " ++ genStructName b ++ " b;") com) +    in Map.singleton name (StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com)    STEither a b ->      let name = genStructName (STEither a b)  -- 0 -> a, 1 -> b -    in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ genStructName a ++ " a; " ++ genStructName b ++ " b; };") com) +    in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com)    STMaybe t ->      let name = genStructName (STMaybe t)  -- 0 -> nothing, 1 -> just -    in Map.singleton name (StructDecl name ("uint8_t tag; " ++ genStructName t ++ " a;") com) +    in Map.singleton name (StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com)    STArr n t ->      let name = genStructName (STArr n t) -    in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ genStructName t ++ " *a;") com) +    in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ repTy t ++ " *a;") com)    STScal _ -> mempty    STAccum t ->      let name = genStructName (STAccum t) -    in Map.singleton name (StructDecl name (genStructName t ++ " a;") com) +    in Map.singleton name (StructDecl name (repTy t ++ " a;") com)           <> genStruct t    where      com = ppTy 0 topty @@ -157,13 +159,20 @@ emitStruct ty = do    modify $ \s -> s { csStructs = genStruct ty <> csStructs s }    return (genStructName ty) -compile :: SList (Const String) env -> Ex env t -> String +nameEnv :: SList f env -> SList (Const String) env +nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) + +compile :: SList STy env -> Ex env t -> String  compile env expr = -  let (res, s) = runState (compile' env expr) (CompState mempty mempty 1) +  let args = nameEnv env +      (res, s) = runState (compile' args expr) (CompState mempty mempty 1)    in ($ "") $ compose         [compose $ map (\sd -> printStructDecl sd . showString "\n") (Map.elems (csStructs s))         ,showString "\n" -       ,showString (genStructName (typeOf expr) ++ " kernel(" ++ intercalate ", " (reverse (unSList getConst env)) ++ ") {\n") +       ,showString $ +          repTy (typeOf expr) ++ " kernel(" ++ +            intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repTy t ++ " " ++ getConst n) (slistZip env args))) ++ +            ") {\n"         ,compose $ map (\st -> showString "  " . printStmt 1 st . showString "\n") (toList (csStmts s))         ,showString ("  return ") . printCExpr res . showString ";\n}\n"] @@ -174,7 +183,7 @@ compile' env = \case    ELet _ rhs body -> do      e <- compile' env rhs      var <- genName -    emit $ SVarDecl (genStructName (typeOf rhs)) var e +    emit $ SVarDecl True (repTy (typeOf rhs)) var e      compile' (Const var `SCons` env) body    EPair _ a b -> do @@ -201,6 +210,17 @@ compile' env = \case      e2 <- compile' env e      return $ CEStruct name [("tag", CELit "1"), ("b", e2)] +  ECase _ (EOp _ OIf e) a b -> do +    e1 <- compile' env e +    (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a  -- don't access that nil, stupid you +    (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b +    retvar <- genName +    emit $ SVarDeclUninit (repTy (typeOf a)) retvar +    emit $ SIf e1 +             (stmts2 <> pure (SAsg retvar e2)) +             (stmts3 <> pure (SAsg retvar e3)) +    return (CELit retvar) +    ECase _ e a b -> do      let STEither t1 t2 = typeOf e      e1 <- compile' env e @@ -209,14 +229,14 @@ compile' env = \case      (e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a      (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b      retvar <- genName -    emit $ SVarDeclUninit (genStructName (typeOf a)) retvar -    emit $ SBlock (pure (SVarDecl (genStructName (typeOf e)) var e1) +    emit $ SVarDeclUninit (repTy (typeOf a)) retvar +    emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1)                  <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) -                           (pure (SVarDecl (genStructName t1) fieldvar +                           (pure (SVarDecl True (repTy t1) fieldvar                                             (CEProj (CELit var) "a"))                              <> stmts2                              <> pure (SAsg retvar e2)) -                           (pure (SVarDecl (genStructName t2) fieldvar +                           (pure (SVarDecl True (repTy t2) fieldvar                                             (CEProj (CELit var) "b"))                              <> stmts3                              <> pure (SAsg retvar e3)))) @@ -238,12 +258,12 @@ compile' env = \case      (e2, stmts2) <- scope $ compile' env a      (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b      retvar <- genName -    emit $ SVarDeclUninit (genStructName (typeOf a)) retvar -    emit $ SBlock (pure (SVarDecl (genStructName (typeOf e)) var e1) +    emit $ SVarDeclUninit (repTy (typeOf a)) retvar +    emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1)                  <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))                             (stmts2                              <> pure (SAsg retvar e2)) -                           (pure (SVarDecl (genStructName (typeOf b)) fieldvar +                           (pure (SVarDecl True (repTy (typeOf b)) fieldvar                                             (CEProj (CELit var) "a"))                              <> stmts3                              <> pure (SAsg retvar e3)))) @@ -282,36 +302,14 @@ compile' env = \case    EShape _ e -> error "TODO" -- EShape ext (compile' e) +  EOp _ op (EPair _ e1 e2) -> do +    e1' <- compile' env e1 +    e2' <- compile' env e2 +    compileOpPair op e1' e2' +    EOp _ op e -> do -    e1 <- compile' env e -    let unary cop = return @(State CompState) $ CECall cop [e1] -    let binary cop = do -          name <- genName -          emit $ SVarDecl (genStructName (typeOf e)) name e1 -          return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b") -    case op of -      OAdd _ -> binary "+" -      OMul _ -> binary "*" -      ONeg _ -> unary "-" -      OLt _ -> binary "<" -      OLe _ -> binary "<=" -      OEq _ -> binary "==" -      ONot -> unary "!" -      OAnd -> binary "&&" -      OOr -> binary "||" -      OIf -> do -        name <- emitStruct (STEither STNil STNil) -        _ <- emitStruct STNil -        return $ CEIf e1 (CEStruct name [("tag", CELit "0")]) -                         (CEStruct name [("tag", CELit "1")]) -      ORound64 -> unary "(int64_t)round"  -- ew -      OToFl64 -> unary "(double)" -      ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1 -      OExp STF32 -> unary "expf" -      OExp STF64 -> unary "exp" -      OLog STF32 -> unary "logf" -      OLog STF64 -> unary "log" -      OIDiv _ -> binary "/" +    e' <- compile' env e +    compileOpGeneral op e'    ECustom _ t1 t2 t3 a b c e1 e2 -> error "TODO" -- ECustom ext t1 t2 t3 (compile' a) (compile' b) (compile' c) (compile' e1) (compile' e2) @@ -334,7 +332,7 @@ compileOpGeneral op e1 = do    let unary cop = return @(State CompState) $ CECall cop [e1]    let binary cop = do          name <- genName -        emit $ SVarDecl (genStructName (opt1 op)) name e1 +        emit $ SVarDecl True (repTy (opt1 op)) name e1          return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b")    case op of      OAdd _ -> binary "+" | 
