diff options
author | Tom Smeding <tom@tomsmeding.com> | 2023-09-19 21:55:38 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2023-09-19 21:55:38 +0200 |
commit | 183e8b4a07231aae904b8234ddeb1c646c031173 (patch) | |
tree | d514667bb46f5bf6553abed83042a5771e3c39f2 | |
parent | 7095bcf4910e2b1525234ca8e88f4effc25315bd (diff) |
Stuff
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/AST.hs | 29 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 72 | ||||
-rw-r--r-- | src/Example.hs | 18 | ||||
-rw-r--r-- | src/Simplify.hs | 164 |
5 files changed, 267 insertions, 17 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index ac0df0f..df39a18 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -17,6 +17,7 @@ library -- Compile Example PreludeCu + Simplify other-modules: build-depends: base >= 4.14 && < 4.19, @@ -11,6 +11,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveTraversable #-} module AST (module AST, module AST.Weaken) where import Data.Functor.Const @@ -30,11 +31,12 @@ data SNat n where deriving instance (Show (SNat n)) data Vec n t where - VNil :: Vec n t + VNil :: Vec Z t (:<) :: t -> Vec n t -> Vec (S n) t deriving instance Show t => Show (Vec n t) deriving instance Functor (Vec n) deriving instance Foldable (Vec n) +deriving instance Traversable (Vec n) data SList f l where SNil :: SList f '[] @@ -92,6 +94,9 @@ type family ConsN n x l where ConsN Z x l = l ConsN (S n) x l = x : ConsN n x l +-- General assumption: head of the list (whatever way it is associated) is the +-- inner variable / inner array dimension. In pretty printing, the inner +-- variable / inner dimension is printed on the _right_. type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type data Expr x env t where -- lambda calculus @@ -109,7 +114,7 @@ data Expr x env t where -- array operations EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) - EBuild :: x (TArr n t) -> SNat n -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t) + EBuild :: x (TArr n t) -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t) EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) -- expression operations @@ -168,7 +173,7 @@ typeOf = \case ECase _ _ a _ -> typeOf a EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) - EBuild _ n _ e -> STArr n (typeOf e) + EBuild _ es e -> STArr (vecLength es) (typeOf e) EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t EConst _ t _ -> STScal t @@ -212,6 +217,10 @@ fromNat :: Nat -> Int fromNat Z = 0 fromNat (S n) = succ (fromNat n) +vecLength :: Vec n t -> SNat n +vecLength VNil = SZ +vecLength (_ :< v) = SS (vecLength v) + infixr @> (@>) :: env :> env' -> Idx env t -> Idx env' t WId @> i = i @@ -233,7 +242,7 @@ weakenExpr w = \case EInr x t e -> EInr x t (weakenExpr w e) ECase x e1 e2 e3 -> ECase x (weakenExpr w e1) (weakenExpr (WCopy w) e2) (weakenExpr (WCopy w) e3) EBuild1 x e1 e2 -> EBuild1 x (weakenExpr w e1) (weakenExpr (WCopy w) e2) - EBuild x n es e -> EBuild x n (weakenVec w es) (weakenExpr (wcopyN n w) e) + EBuild x es e -> EBuild x (weakenVec w es) (weakenExpr (wcopyN (vecLength es) w) e) EFold1 x e1 e2 -> EFold1 x (weakenExpr (WCopy (WCopy w)) e1) (weakenExpr w e2) EConst x t v -> EConst x t v EIdx1 x e1 e2 -> EIdx1 x (weakenExpr w e1) (weakenExpr w e2) @@ -245,10 +254,18 @@ weakenExpr w = \case EMBind e1 e2 -> EMBind (weakenExpr w e1) (weakenExpr (WCopy w) e2) EError t s -> EError t s +wsinkN :: SNat n -> env :> ConsN n TIx env +wsinkN SZ = WId +wsinkN (SS n) = WSink .> wsinkN n + wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env' wcopyN SZ w = w wcopyN (SS n) w = WCopy (wcopyN n w) +wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env' +wpopN SZ w = w +wpopN (SS n) w = wpopN n (WPop w) + weakenVec :: (env :> env') -> Vec n (Expr x env TIx) -> Vec n (Expr x env' TIx) weakenVec _ VNil = VNil weakenVec w (e :< v) = weakenExpr w e :< weakenVec w v @@ -256,3 +273,7 @@ weakenVec w (e :< v) = weakenExpr w e :< weakenVec w v slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list slistMap _ SNil = SNil slistMap f (SCons x list) = SCons (f x) (slistMap f list) + +idx2int :: Idx env t -> Int +idx2int IZ = 0 +idx2int (IS n) = 1 + idx2int n diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index c1d6c88..e793ce1 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -10,6 +10,7 @@ module AST.Pretty where import Control.Monad (ap) import Data.List (intersperse) +import Data.Foldable (toList) import Data.Functor.Const import AST @@ -26,6 +27,10 @@ valprj (VPush x _) IZ = x valprj (VPush _ env) (IS i) = valprj env i valprj VTop i = case i of {} +vpushN :: Vec n a -> Val (Const a) env -> Val (Const a) (ConsN n TIx env) +vpushN VNil v = v +vpushN (name :< names) v = VPush (Const name) (vpushN names v) + newtype M a = M { runM :: Int -> (a, Int) } deriving (Functor) instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap } @@ -88,11 +93,11 @@ ppExpr' d val = \case EInl _ _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "inl " . e' + return $ showParen (d > 10) $ showString "Inl " . e' EInr _ _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "inr " . e' + return $ showParen (d > 10) $ showString "Inr " . e' ECase _ e a b -> do e' <- ppExpr' 0 val e @@ -104,9 +109,50 @@ ppExpr' d val = \case showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a' . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }" + EBuild1 _ a b -> do + a' <- ppExpr' 11 val a + name <- genName + b' <- ppExpr' 0 (VPush (Const name) val) b + return $ showParen (d > 10) $ + showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")" + + EBuild _ es e -> do + es' <- mapM (ppExpr' 0 val) es + names <- mapM (const genName) es + e' <- ppExpr' 0 (vpushN names val) e + return $ showParen (d > 10) $ + showString "build [" + . foldr (.) id (intersperse (showString ", ") (reverse (toList es'))) + . showString "] (\\[" + . foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names)))) + . showString ("] -> ") . e' . showString ")" + + EFold1 _ a b -> do + name1 <- genName + name2 <- genName + a' <- ppExpr' 0 (VPush (Const name2) (VPush (Const name1) val)) a + b' <- ppExpr' 11 val b + return $ showParen (d > 10) $ + showString ("fold1 (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a' + . showString ") " . b' + EConst _ ty v -> return $ showString $ case ty of STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v + EIdx1 _ a b -> do + a' <- ppExpr' 9 val a + b' <- ppExpr' 9 val b + return $ showParen (d > 8) $ a' . showString " ! " . b' + + EIdx _ e es -> do + e' <- ppExpr' 9 val e + es' <- traverse (ppExpr' 0 val) es + return $ showParen (d > 8) $ + e' . showString " ! " + . showString "[" + . foldr (.) id (intersperse (showString ", ") (reverse (toList es'))) + . showString "]" + EOp _ op (EPair _ a b) | (Infix, ops) <- operator op -> do a' <- ppExpr' 9 val a @@ -139,6 +185,22 @@ ppExpr' d val = \case e' <- ppExpr' 11 val e return $ showParen (d > 10) $ showString ("return ") . e' + etop@(EMBind _ EMBind{}) -> do + let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS) + collect val' (EMBind lhs cont) = do + name <- genName + (binds, core) <- collect (VPush (Const name) val') cont + lhs' <- ppExpr' 0 val' lhs + return ((name, lhs') : binds, core) + collect val' e = ([],) <$> ppExpr' 0 val' e + + (binds, core) <- collect val etop + return $ showParen (d > 0) $ + showString "do { " + . foldr (.) id (intersperse (showString " ; ") + (map (\(name, rhs) -> showString (name ++ " <- ") . rhs) binds)) + . showString " ; " . core . showString " }" + EMBind a b -> do a' <- ppExpr' 0 val a name <- genName @@ -147,8 +209,6 @@ ppExpr' d val = \case EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) - _ -> undefined - data Fixity = Prefix | Infix deriving (Show) @@ -160,7 +220,3 @@ operator OLt{} = (Infix, "<") operator OLe{} = (Infix, "<=") operator OEq{} = (Infix, "==") operator ONot = (Prefix, "not") - -idx2int :: Idx env t -> Int -idx2int IZ = 0 -idx2int (IS n) = 1 + idx2int n diff --git a/src/Example.hs b/src/Example.hs index 99574c5..c8f12ba 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -4,6 +4,7 @@ module Example where import AST import AST.Pretty import CHAD +import Simplify bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c @@ -18,12 +19,12 @@ senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil -- x4 = ((*) x3, x1) -- in ( (+) x4 -- , let x5 = 1.0 --- x6 = inr (x5, x5) +-- x6 = Inr (x5, x5) -- in case x6 of -- Inl x7 -> return () -- Inr x8 -> -- let x9 = fst x8 --- x10 = inr (snd x3 * x9, fst x3 * x9) +-- x10 = Inr (snd x3 * x9, fst x3 * x9) -- in case x10 of -- Inl x11 -> return () -- Inr x12 -> @@ -47,6 +48,13 @@ ex1 = (EVar ext (STScal STF32) IZ)) (EVar ext (STScal STF32) (IS IZ)) --- -- x y |- let z = x + y in z * (z + x) --- ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32) --- ex2 = _ +-- x y |- let z = x + y in z * (z + x) +ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex2 = + ELet ext (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ)) + (EVar ext (STScal STF32) IZ)) $ + bin (OMul STF32) + (EVar ext (STScal STF32) IZ) + (bin (OAdd STF32) + (EVar ext (STScal STF32) IZ) + (EVar ext (STScal STF32) (IS (IS IZ)))) diff --git a/src/Simplify.hs b/src/Simplify.hs new file mode 100644 index 0000000..cb649d5 --- /dev/null +++ b/src/Simplify.hs @@ -0,0 +1,164 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DataKinds #-} +module Simplify where + +import AST + + +simplify :: Ex env t -> Ex env t +simplify = \case + -- inlining + ELet _ rhs body + | Occ lexOcc runOcc <- occCount IZ body + , lexOcc <= One -- prevent code size blowup + , runOcc <= One -- prevent runtime increase + -> simplify (subst1 rhs body) + | cheapExpr rhs + -> simplify (subst1 rhs body) + + -- let splitting + ELet _ (EPair _ a b) body -> + simplify $ + ELet ext a $ + ELet ext (weakenExpr WSink b) $ + subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ) + IS i -> EVar ext t (IS (IS i))) + body + + EFst _ (EPair _ e _) -> simplify e + ESnd _ (EPair _ _ e) -> simplify e + + ECase _ (EInl _ _ e) rhs _ -> simplify (ELet ext e rhs) + ECase _ (EInr _ _ e) _ rhs -> simplify (ELet ext e rhs) + + -- TODO: array indexing (index of build, index of fold) + + -- TODO: constant folding for operations + + EVar _ t i -> EVar ext t i + ELet _ a b -> ELet ext (simplify a) (simplify b) + EPair _ a b -> EPair ext (simplify a) (simplify b) + EFst _ e -> EFst ext (simplify e) + ESnd _ e -> ESnd ext (simplify e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext t (simplify e) + EInr _ t e -> EInr ext t (simplify e) + ECase _ e a b -> ECase ext (simplify e) (simplify a) (simplify b) + EBuild1 _ a b -> EBuild1 ext (simplify a) (simplify b) + EBuild _ es e -> EBuild ext (fmap simplify es) (simplify e) + EFold1 _ a b -> EFold1 ext (simplify a) (simplify b) + EConst _ t v -> EConst ext t v + EIdx1 _ a b -> EIdx1 ext (simplify a) (simplify b) + EIdx _ e es -> EIdx ext (simplify e) (fmap simplify es) + EOp _ op e -> EOp ext op (simplify e) + EMOne t i e -> EMOne t i (simplify e) + EMScope e -> EMScope (simplify e) + EMReturn t e -> EMReturn t (simplify e) + EMBind a b -> EMBind (simplify a) (simplify b) + EError t s -> EError t s + +cheapExpr :: Expr x env t -> Bool +cheapExpr = \case + EVar{} -> True + ENil{} -> True + EConst{} -> True + _ -> False + +data Count = Zero | One | Many + deriving (Show, Eq, Ord) + +instance Semigroup Count where + Zero <> n = n + n <> Zero = n + _ <> _ = Many +instance Monoid Count where + mempty = Zero + +data Occ = Occ { _occLexical :: Count + , _occRuntime :: Count } +instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d) +instance Monoid Occ where mempty = Occ mempty mempty + +-- | One of the two branches is taken +(<||>) :: Occ -> Occ -> Occ +Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) + +-- | This code is executed many times +scaleMany :: Occ -> Occ +scaleMany (Occ l _) = Occ l Many + +occCount :: Idx env a -> Expr x env t -> Occ +occCount idx = \case + EVar _ _ i | idx2int i == idx2int idx -> Occ One One + | otherwise -> mempty + ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body + EPair _ a b -> occCount idx a <> occCount idx b + EFst _ e -> occCount idx e + ESnd _ e -> occCount idx e + ENil _ -> mempty + EInl _ _ e -> occCount idx e + EInr _ _ e -> occCount idx e + ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b) + EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b) + EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e) + EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b + EConst{} -> mempty + EIdx1 _ a b -> occCount idx a <> occCount idx b + EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es + EOp _ _ e -> occCount idx e + EMOne _ _ e -> occCount idx e + EMScope e -> occCount idx e + EMReturn _ e -> occCount idx e + EMBind a b -> occCount idx a <> occCount (IS idx) b + EError{} -> mempty + +subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t +subst1 repl = subst $ \x t -> \case IZ -> repl + IS i -> EVar x t i + +subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) + -> Expr x env t -> Expr x env' t +subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId + +subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) + -> env' :> envOut + -> Expr x env t + -> Expr x envOut t +subst' f w = \case + EVar x t i -> f x t w i + ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body) + EPair x a b -> EPair x (subst' f w a) (subst' f w b) + EFst x e -> EFst x (subst' f w e) + ESnd x e -> ESnd x (subst' f w e) + ENil x -> ENil x + EInl x t e -> EInl x t (subst' f w e) + EInr x t e -> EInr x t (subst' f w e) + ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) + EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) + EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e) + EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) + EConst x t v -> EConst x t v + EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) + EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es) + EOp x op e -> EOp x op (subst' f w e) + EMOne t i e -> EMOne t i (subst' f w e) + EMScope e -> EMScope (subst' f w e) + EMReturn t e -> EMReturn t (subst' f w e) + EMBind a b -> EMBind (subst' f w a) (subst' (sinkF f) (WCopy w) b) + EError t s -> EError t s + where + sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) + -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t + sinkF f' x' t w' = \case + IZ -> EVar x' t (w' @> IZ) + IS i -> f' x' t (WPop w') i + + sinkFN :: SNat n + -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) + -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t + sinkFN SZ f' x t w' i = f' x t w' i + sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ) + sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i |