summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2023-09-19 21:55:38 +0200
committerTom Smeding <tom@tomsmeding.com>2023-09-19 21:55:38 +0200
commit183e8b4a07231aae904b8234ddeb1c646c031173 (patch)
treed514667bb46f5bf6553abed83042a5771e3c39f2
parent7095bcf4910e2b1525234ca8e88f4effc25315bd (diff)
Stuff
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/AST.hs29
-rw-r--r--src/AST/Pretty.hs72
-rw-r--r--src/Example.hs18
-rw-r--r--src/Simplify.hs164
5 files changed, 267 insertions, 17 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index ac0df0f..df39a18 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -17,6 +17,7 @@ library
-- Compile
Example
PreludeCu
+ Simplify
other-modules:
build-depends:
base >= 4.14 && < 4.19,
diff --git a/src/AST.hs b/src/AST.hs
index 7c5de11..dfc114d 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -11,6 +11,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveTraversable #-}
module AST (module AST, module AST.Weaken) where
import Data.Functor.Const
@@ -30,11 +31,12 @@ data SNat n where
deriving instance (Show (SNat n))
data Vec n t where
- VNil :: Vec n t
+ VNil :: Vec Z t
(:<) :: t -> Vec n t -> Vec (S n) t
deriving instance Show t => Show (Vec n t)
deriving instance Functor (Vec n)
deriving instance Foldable (Vec n)
+deriving instance Traversable (Vec n)
data SList f l where
SNil :: SList f '[]
@@ -92,6 +94,9 @@ type family ConsN n x l where
ConsN Z x l = l
ConsN (S n) x l = x : ConsN n x l
+-- General assumption: head of the list (whatever way it is associated) is the
+-- inner variable / inner array dimension. In pretty printing, the inner
+-- variable / inner dimension is printed on the _right_.
type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
data Expr x env t where
-- lambda calculus
@@ -109,7 +114,7 @@ data Expr x env t where
-- array operations
EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
- EBuild :: x (TArr n t) -> SNat n -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t)
+ EBuild :: x (TArr n t) -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t)
EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
-- expression operations
@@ -168,7 +173,7 @@ typeOf = \case
ECase _ _ a _ -> typeOf a
EBuild1 _ _ e -> STArr (SS SZ) (typeOf e)
- EBuild _ n _ e -> STArr n (typeOf e)
+ EBuild _ es e -> STArr (vecLength es) (typeOf e)
EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
EConst _ t _ -> STScal t
@@ -212,6 +217,10 @@ fromNat :: Nat -> Int
fromNat Z = 0
fromNat (S n) = succ (fromNat n)
+vecLength :: Vec n t -> SNat n
+vecLength VNil = SZ
+vecLength (_ :< v) = SS (vecLength v)
+
infixr @>
(@>) :: env :> env' -> Idx env t -> Idx env' t
WId @> i = i
@@ -233,7 +242,7 @@ weakenExpr w = \case
EInr x t e -> EInr x t (weakenExpr w e)
ECase x e1 e2 e3 -> ECase x (weakenExpr w e1) (weakenExpr (WCopy w) e2) (weakenExpr (WCopy w) e3)
EBuild1 x e1 e2 -> EBuild1 x (weakenExpr w e1) (weakenExpr (WCopy w) e2)
- EBuild x n es e -> EBuild x n (weakenVec w es) (weakenExpr (wcopyN n w) e)
+ EBuild x es e -> EBuild x (weakenVec w es) (weakenExpr (wcopyN (vecLength es) w) e)
EFold1 x e1 e2 -> EFold1 x (weakenExpr (WCopy (WCopy w)) e1) (weakenExpr w e2)
EConst x t v -> EConst x t v
EIdx1 x e1 e2 -> EIdx1 x (weakenExpr w e1) (weakenExpr w e2)
@@ -245,10 +254,18 @@ weakenExpr w = \case
EMBind e1 e2 -> EMBind (weakenExpr w e1) (weakenExpr (WCopy w) e2)
EError t s -> EError t s
+wsinkN :: SNat n -> env :> ConsN n TIx env
+wsinkN SZ = WId
+wsinkN (SS n) = WSink .> wsinkN n
+
wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env'
wcopyN SZ w = w
wcopyN (SS n) w = WCopy (wcopyN n w)
+wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env'
+wpopN SZ w = w
+wpopN (SS n) w = wpopN n (WPop w)
+
weakenVec :: (env :> env') -> Vec n (Expr x env TIx) -> Vec n (Expr x env' TIx)
weakenVec _ VNil = VNil
weakenVec w (e :< v) = weakenExpr w e :< weakenVec w v
@@ -256,3 +273,7 @@ weakenVec w (e :< v) = weakenExpr w e :< weakenVec w v
slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list
slistMap _ SNil = SNil
slistMap f (SCons x list) = SCons (f x) (slistMap f list)
+
+idx2int :: Idx env t -> Int
+idx2int IZ = 0
+idx2int (IS n) = 1 + idx2int n
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index c1d6c88..e793ce1 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -10,6 +10,7 @@ module AST.Pretty where
import Control.Monad (ap)
import Data.List (intersperse)
+import Data.Foldable (toList)
import Data.Functor.Const
import AST
@@ -26,6 +27,10 @@ valprj (VPush x _) IZ = x
valprj (VPush _ env) (IS i) = valprj env i
valprj VTop i = case i of {}
+vpushN :: Vec n a -> Val (Const a) env -> Val (Const a) (ConsN n TIx env)
+vpushN VNil v = v
+vpushN (name :< names) v = VPush (Const name) (vpushN names v)
+
newtype M a = M { runM :: Int -> (a, Int) }
deriving (Functor)
instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap }
@@ -88,11 +93,11 @@ ppExpr' d val = \case
EInl _ _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "inl " . e'
+ return $ showParen (d > 10) $ showString "Inl " . e'
EInr _ _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "inr " . e'
+ return $ showParen (d > 10) $ showString "Inr " . e'
ECase _ e a b -> do
e' <- ppExpr' 0 val e
@@ -104,9 +109,50 @@ ppExpr' d val = \case
showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
. showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }"
+ EBuild1 _ a b -> do
+ a' <- ppExpr' 11 val a
+ name <- genName
+ b' <- ppExpr' 0 (VPush (Const name) val) b
+ return $ showParen (d > 10) $
+ showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")"
+
+ EBuild _ es e -> do
+ es' <- mapM (ppExpr' 0 val) es
+ names <- mapM (const genName) es
+ e' <- ppExpr' 0 (vpushN names val) e
+ return $ showParen (d > 10) $
+ showString "build ["
+ . foldr (.) id (intersperse (showString ", ") (reverse (toList es')))
+ . showString "] (\\["
+ . foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names))))
+ . showString ("] -> ") . e' . showString ")"
+
+ EFold1 _ a b -> do
+ name1 <- genName
+ name2 <- genName
+ a' <- ppExpr' 0 (VPush (Const name2) (VPush (Const name1) val)) a
+ b' <- ppExpr' 11 val b
+ return $ showParen (d > 10) $
+ showString ("fold1 (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a'
+ . showString ") " . b'
+
EConst _ ty v -> return $ showString $ case ty of
STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v
+ EIdx1 _ a b -> do
+ a' <- ppExpr' 9 val a
+ b' <- ppExpr' 9 val b
+ return $ showParen (d > 8) $ a' . showString " ! " . b'
+
+ EIdx _ e es -> do
+ e' <- ppExpr' 9 val e
+ es' <- traverse (ppExpr' 0 val) es
+ return $ showParen (d > 8) $
+ e' . showString " ! "
+ . showString "["
+ . foldr (.) id (intersperse (showString ", ") (reverse (toList es')))
+ . showString "]"
+
EOp _ op (EPair _ a b)
| (Infix, ops) <- operator op -> do
a' <- ppExpr' 9 val a
@@ -139,6 +185,22 @@ ppExpr' d val = \case
e' <- ppExpr' 11 val e
return $ showParen (d > 10) $ showString ("return ") . e'
+ etop@(EMBind _ EMBind{}) -> do
+ let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
+ collect val' (EMBind lhs cont) = do
+ name <- genName
+ (binds, core) <- collect (VPush (Const name) val') cont
+ lhs' <- ppExpr' 0 val' lhs
+ return ((name, lhs') : binds, core)
+ collect val' e = ([],) <$> ppExpr' 0 val' e
+
+ (binds, core) <- collect val etop
+ return $ showParen (d > 0) $
+ showString "do { "
+ . foldr (.) id (intersperse (showString " ; ")
+ (map (\(name, rhs) -> showString (name ++ " <- ") . rhs) binds))
+ . showString " ; " . core . showString " }"
+
EMBind a b -> do
a' <- ppExpr' 0 val a
name <- genName
@@ -147,8 +209,6 @@ ppExpr' d val = \case
EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
- _ -> undefined
-
data Fixity = Prefix | Infix
deriving (Show)
@@ -160,7 +220,3 @@ operator OLt{} = (Infix, "<")
operator OLe{} = (Infix, "<=")
operator OEq{} = (Infix, "==")
operator ONot = (Prefix, "not")
-
-idx2int :: Idx env t -> Int
-idx2int IZ = 0
-idx2int (IS n) = 1 + idx2int n
diff --git a/src/Example.hs b/src/Example.hs
index 99574c5..c8f12ba 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -4,6 +4,7 @@ module Example where
import AST
import AST.Pretty
import CHAD
+import Simplify
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
@@ -18,12 +19,12 @@ senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil
-- x4 = ((*) x3, x1)
-- in ( (+) x4
-- , let x5 = 1.0
--- x6 = inr (x5, x5)
+-- x6 = Inr (x5, x5)
-- in case x6 of
-- Inl x7 -> return ()
-- Inr x8 ->
-- let x9 = fst x8
--- x10 = inr (snd x3 * x9, fst x3 * x9)
+-- x10 = Inr (snd x3 * x9, fst x3 * x9)
-- in case x10 of
-- Inl x11 -> return ()
-- Inr x12 ->
@@ -47,6 +48,13 @@ ex1 =
(EVar ext (STScal STF32) IZ))
(EVar ext (STScal STF32) (IS IZ))
--- -- x y |- let z = x + y in z * (z + x)
--- ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
--- ex2 = _
+-- x y |- let z = x + y in z * (z + x)
+ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex2 =
+ ELet ext (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ))
+ (EVar ext (STScal STF32) IZ)) $
+ bin (OMul STF32)
+ (EVar ext (STScal STF32) IZ)
+ (bin (OAdd STF32)
+ (EVar ext (STScal STF32) IZ)
+ (EVar ext (STScal STF32) (IS (IS IZ))))
diff --git a/src/Simplify.hs b/src/Simplify.hs
new file mode 100644
index 0000000..cb649d5
--- /dev/null
+++ b/src/Simplify.hs
@@ -0,0 +1,164 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE DataKinds #-}
+module Simplify where
+
+import AST
+
+
+simplify :: Ex env t -> Ex env t
+simplify = \case
+ -- inlining
+ ELet _ rhs body
+ | Occ lexOcc runOcc <- occCount IZ body
+ , lexOcc <= One -- prevent code size blowup
+ , runOcc <= One -- prevent runtime increase
+ -> simplify (subst1 rhs body)
+ | cheapExpr rhs
+ -> simplify (subst1 rhs body)
+
+ -- let splitting
+ ELet _ (EPair _ a b) body ->
+ simplify $
+ ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
+ IS i -> EVar ext t (IS (IS i)))
+ body
+
+ EFst _ (EPair _ e _) -> simplify e
+ ESnd _ (EPair _ _ e) -> simplify e
+
+ ECase _ (EInl _ _ e) rhs _ -> simplify (ELet ext e rhs)
+ ECase _ (EInr _ _ e) _ rhs -> simplify (ELet ext e rhs)
+
+ -- TODO: array indexing (index of build, index of fold)
+
+ -- TODO: constant folding for operations
+
+ EVar _ t i -> EVar ext t i
+ ELet _ a b -> ELet ext (simplify a) (simplify b)
+ EPair _ a b -> EPair ext (simplify a) (simplify b)
+ EFst _ e -> EFst ext (simplify e)
+ ESnd _ e -> ESnd ext (simplify e)
+ ENil _ -> ENil ext
+ EInl _ t e -> EInl ext t (simplify e)
+ EInr _ t e -> EInr ext t (simplify e)
+ ECase _ e a b -> ECase ext (simplify e) (simplify a) (simplify b)
+ EBuild1 _ a b -> EBuild1 ext (simplify a) (simplify b)
+ EBuild _ es e -> EBuild ext (fmap simplify es) (simplify e)
+ EFold1 _ a b -> EFold1 ext (simplify a) (simplify b)
+ EConst _ t v -> EConst ext t v
+ EIdx1 _ a b -> EIdx1 ext (simplify a) (simplify b)
+ EIdx _ e es -> EIdx ext (simplify e) (fmap simplify es)
+ EOp _ op e -> EOp ext op (simplify e)
+ EMOne t i e -> EMOne t i (simplify e)
+ EMScope e -> EMScope (simplify e)
+ EMReturn t e -> EMReturn t (simplify e)
+ EMBind a b -> EMBind (simplify a) (simplify b)
+ EError t s -> EError t s
+
+cheapExpr :: Expr x env t -> Bool
+cheapExpr = \case
+ EVar{} -> True
+ ENil{} -> True
+ EConst{} -> True
+ _ -> False
+
+data Count = Zero | One | Many
+ deriving (Show, Eq, Ord)
+
+instance Semigroup Count where
+ Zero <> n = n
+ n <> Zero = n
+ _ <> _ = Many
+instance Monoid Count where
+ mempty = Zero
+
+data Occ = Occ { _occLexical :: Count
+ , _occRuntime :: Count }
+instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d)
+instance Monoid Occ where mempty = Occ mempty mempty
+
+-- | One of the two branches is taken
+(<||>) :: Occ -> Occ -> Occ
+Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
+
+-- | This code is executed many times
+scaleMany :: Occ -> Occ
+scaleMany (Occ l _) = Occ l Many
+
+occCount :: Idx env a -> Expr x env t -> Occ
+occCount idx = \case
+ EVar _ _ i | idx2int i == idx2int idx -> Occ One One
+ | otherwise -> mempty
+ ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body
+ EPair _ a b -> occCount idx a <> occCount idx b
+ EFst _ e -> occCount idx e
+ ESnd _ e -> occCount idx e
+ ENil _ -> mempty
+ EInl _ _ e -> occCount idx e
+ EInr _ _ e -> occCount idx e
+ ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b)
+ EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b)
+ EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e)
+ EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b
+ EConst{} -> mempty
+ EIdx1 _ a b -> occCount idx a <> occCount idx b
+ EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es
+ EOp _ _ e -> occCount idx e
+ EMOne _ _ e -> occCount idx e
+ EMScope e -> occCount idx e
+ EMReturn _ e -> occCount idx e
+ EMBind a b -> occCount idx a <> occCount (IS idx) b
+ EError{} -> mempty
+
+subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
+subst1 repl = subst $ \x t -> \case IZ -> repl
+ IS i -> EVar x t i
+
+subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
+ -> Expr x env t -> Expr x env' t
+subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId
+
+subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
+ -> env' :> envOut
+ -> Expr x env t
+ -> Expr x envOut t
+subst' f w = \case
+ EVar x t i -> f x t w i
+ ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
+ EPair x a b -> EPair x (subst' f w a) (subst' f w b)
+ EFst x e -> EFst x (subst' f w e)
+ ESnd x e -> ESnd x (subst' f w e)
+ ENil x -> ENil x
+ EInl x t e -> EInl x t (subst' f w e)
+ EInr x t e -> EInr x t (subst' f w e)
+ ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
+ EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
+ EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e)
+ EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
+ EConst x t v -> EConst x t v
+ EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
+ EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
+ EOp x op e -> EOp x op (subst' f w e)
+ EMOne t i e -> EMOne t i (subst' f w e)
+ EMScope e -> EMScope (subst' f w e)
+ EMReturn t e -> EMReturn t (subst' f w e)
+ EMBind a b -> EMBind (subst' f w a) (subst' (sinkF f) (WCopy w) b)
+ EError t s -> EError t s
+ where
+ sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
+ -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
+ sinkF f' x' t w' = \case
+ IZ -> EVar x' t (w' @> IZ)
+ IS i -> f' x' t (WPop w') i
+
+ sinkFN :: SNat n
+ -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
+ -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t
+ sinkFN SZ f' x t w' i = f' x t w' i
+ sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ)
+ sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i