diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-01-27 23:08:14 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-01-27 23:08:14 +0100 |
commit | f0150137969758ee7255ade3c90db915bc8542df (patch) | |
tree | d8b1ce58a503d39af08028a0cb90b912531c0532 | |
parent | d33a19b58de5af8017e1ae5fae06a9378379e20c (diff) |
Ramblings in Compile
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/Compile.hs | 54 |
2 files changed, 25 insertions, 30 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 66452d9..c38c270 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -26,6 +26,7 @@ library base >= 4.14 && < 4.19, containers, template-haskell, + transformers, some hs-source-dirs: src diff --git a/src/Compile.hs b/src/Compile.hs index 2fcff5d..749368a 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -4,14 +4,17 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} module Compile where -import Control.Monad (ap) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict +import Control.Monad.Trans.Writer.CPS import Data.Kind (Type) -import Data.Map.Strict (Map) -import qualified Data.Map.Strict as Map import AST +import Data data Body = Body [Stm] Inline -- body, return expr @@ -48,35 +51,23 @@ prj = \env idx -> go idx env go IZ (EPush v _) = v go (IS i) (EPush _ env) = go i env --- generated global function definitions, generated local statements, function typedef cache (name, decl) -newtype M a = M (Int -> Map Ty (String, String) -> ([FunDef], [Stm], Map Ty (String, String), Int, a)) - deriving (Functor) -instance Applicative M where - pure x = M (\i m -> ([], [], m, i, x)) - (<*>) = ap -instance Monad M where - M f >>= g = M (\i m -> let (d1, s1, m1, i1, x) = f i m - (d2, s2, m2, i2, y) = let M h = g x in h i1 m1 - in (d1 <> d2, s1 <> s2, m2, i2, y)) +newtype M a = M (StateT Int -- ID generator + (WriterT [FunDef] -- generated global function definitions + (Writer [Stm])) -- generated local statements + a) + deriving newtype (Functor, Applicative, Monad) emitFun :: FunDef -> M () -emitFun fd = M (\i m -> ([fd], [], m, i, ())) +emitFun fd = M (lift (tell [fd])) emitStm :: Stm -> M () -emitStm stm = M (\i m -> ([], [stm], m, i, ())) +emitStm stm = M (lift (lift (tell [stm]))) captureStms :: M a -> M ([Stm], a) -captureStms (M f) = M (\i m -> let (d, s, m2, i2, x) = f i m - in (d, [], m2, i2, (s, x))) +captureStms (M m) = M (mapStateT (mapWriterT (mapWriter (\(((x, i), fds), stms) -> ((((stms, x), i), fds), [])))) m) genId :: M Int -genId = M (\i m -> ([], [], m, i + 1, i)) - -getTypedef :: Ty -> M (Maybe String) -getTypedef t = M $ \i m -> ([], [], m, i, fst <$> Map.lookup t m) - -putTypedef :: Ty -> String -> String -> M () -putTypedef t name decl = M $ \i m -> ([], [], Map.insert t (name, decl) m, i, ()) +genId = M (state (\i -> (i, i + 1))) genName :: String -> M String genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId @@ -85,8 +76,7 @@ genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId c : _ | c `elem` "0123456789_" -> ("_", "") | otherwise -> ("", "") --- Function values are returned as a function-pointer-typed expression -compile :: Target -> Env env String -> Expr x env t -> M Inline +compile :: Target -> Env env String -> Ex env t -> M Inline compile tgt env = \case EVar _ _ i -> pure $ IVar (prj env i) ELet _ rhs e -> do @@ -96,17 +86,21 @@ compile tgt env = \case emitStm $ VarDef rhsty var (Just rhsi) compile tgt (EPush var env) e - EBuild1 x k e -> compile tgt env $ EBuild x (SS SZ) (k :< VNil) e - EBuild x n k e -> case tgt of + EBuild1 x k e -> compile tgt env $ EBuild x (k :< VNil) e + EBuild _ VNil e -> _ + EBuild _ k e -> case tgt of Host -> do fname <- genName "buildfun" - let n' = fromNat (unSNat n) + let n' = fromSNat (vecLength k) shapevars = ['s' : show i | i <- [0 .. n' - 1]] + _ = foldr (\a b -> EOp ext (OMul STI64) (EPair ext a b)) (EConst ext STI64 1) k emitFun $ FunDef Device fname (map ("int " ++) shapevars) _ - emitStm $ Launch _ _ _ + emitStm $ Launch _ (ILit "32") _ _ Device -> _ + _ -> undefined + writeType :: STy t -> M String writeType = \case STArr _ t -> (++ "*") <$> writeType t |