summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-01-27 23:08:14 +0100
committerTom Smeding <tom@tomsmeding.com>2024-01-27 23:08:14 +0100
commitf0150137969758ee7255ade3c90db915bc8542df (patch)
treed8b1ce58a503d39af08028a0cb90b912531c0532
parentd33a19b58de5af8017e1ae5fae06a9378379e20c (diff)
Ramblings in Compile
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/Compile.hs54
2 files changed, 25 insertions, 30 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 66452d9..c38c270 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -26,6 +26,7 @@ library
base >= 4.14 && < 4.19,
containers,
template-haskell,
+ transformers,
some
hs-source-dirs:
src
diff --git a/src/Compile.hs b/src/Compile.hs
index 2fcff5d..749368a 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -4,14 +4,17 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Compile where
-import Control.Monad (ap)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict
+import Control.Monad.Trans.Writer.CPS
import Data.Kind (Type)
-import Data.Map.Strict (Map)
-import qualified Data.Map.Strict as Map
import AST
+import Data
data Body = Body [Stm] Inline -- body, return expr
@@ -48,35 +51,23 @@ prj = \env idx -> go idx env
go IZ (EPush v _) = v
go (IS i) (EPush _ env) = go i env
--- generated global function definitions, generated local statements, function typedef cache (name, decl)
-newtype M a = M (Int -> Map Ty (String, String) -> ([FunDef], [Stm], Map Ty (String, String), Int, a))
- deriving (Functor)
-instance Applicative M where
- pure x = M (\i m -> ([], [], m, i, x))
- (<*>) = ap
-instance Monad M where
- M f >>= g = M (\i m -> let (d1, s1, m1, i1, x) = f i m
- (d2, s2, m2, i2, y) = let M h = g x in h i1 m1
- in (d1 <> d2, s1 <> s2, m2, i2, y))
+newtype M a = M (StateT Int -- ID generator
+ (WriterT [FunDef] -- generated global function definitions
+ (Writer [Stm])) -- generated local statements
+ a)
+ deriving newtype (Functor, Applicative, Monad)
emitFun :: FunDef -> M ()
-emitFun fd = M (\i m -> ([fd], [], m, i, ()))
+emitFun fd = M (lift (tell [fd]))
emitStm :: Stm -> M ()
-emitStm stm = M (\i m -> ([], [stm], m, i, ()))
+emitStm stm = M (lift (lift (tell [stm])))
captureStms :: M a -> M ([Stm], a)
-captureStms (M f) = M (\i m -> let (d, s, m2, i2, x) = f i m
- in (d, [], m2, i2, (s, x)))
+captureStms (M m) = M (mapStateT (mapWriterT (mapWriter (\(((x, i), fds), stms) -> ((((stms, x), i), fds), [])))) m)
genId :: M Int
-genId = M (\i m -> ([], [], m, i + 1, i))
-
-getTypedef :: Ty -> M (Maybe String)
-getTypedef t = M $ \i m -> ([], [], m, i, fst <$> Map.lookup t m)
-
-putTypedef :: Ty -> String -> String -> M ()
-putTypedef t name decl = M $ \i m -> ([], [], Map.insert t (name, decl) m, i, ())
+genId = M (state (\i -> (i, i + 1)))
genName :: String -> M String
genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId
@@ -85,8 +76,7 @@ genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId
c : _ | c `elem` "0123456789_" -> ("_", "")
| otherwise -> ("", "")
--- Function values are returned as a function-pointer-typed expression
-compile :: Target -> Env env String -> Expr x env t -> M Inline
+compile :: Target -> Env env String -> Ex env t -> M Inline
compile tgt env = \case
EVar _ _ i -> pure $ IVar (prj env i)
ELet _ rhs e -> do
@@ -96,17 +86,21 @@ compile tgt env = \case
emitStm $ VarDef rhsty var (Just rhsi)
compile tgt (EPush var env) e
- EBuild1 x k e -> compile tgt env $ EBuild x (SS SZ) (k :< VNil) e
- EBuild x n k e -> case tgt of
+ EBuild1 x k e -> compile tgt env $ EBuild x (k :< VNil) e
+ EBuild _ VNil e -> _
+ EBuild _ k e -> case tgt of
Host -> do
fname <- genName "buildfun"
- let n' = fromNat (unSNat n)
+ let n' = fromSNat (vecLength k)
shapevars = ['s' : show i | i <- [0 .. n' - 1]]
+ _ = foldr (\a b -> EOp ext (OMul STI64) (EPair ext a b)) (EConst ext STI64 1) k
emitFun $ FunDef Device fname (map ("int " ++) shapevars) _
- emitStm $ Launch _ _ _
+ emitStm $ Launch _ (ILit "32") _
_
Device -> _
+ _ -> undefined
+
writeType :: STy t -> M String
writeType = \case
STArr _ t -> (++ "*") <$> writeType t