summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/AST.hs2
-rw-r--r--src/AST/Count.hs55
-rw-r--r--src/AST/Pretty.hs38
-rw-r--r--src/CHAD.hs60
-rw-r--r--src/Example.hs36
-rw-r--r--src/Simplify.hs58
7 files changed, 164 insertions, 86 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index df39a18..27c5520 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -11,6 +11,7 @@ build-type: Simple
library
exposed-modules:
AST
+ AST.Count
AST.Pretty
AST.Weaken
CHAD
diff --git a/src/AST.hs b/src/AST.hs
index dfc114d..8d795bf 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -147,6 +147,7 @@ data SOp a t where
OLe :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
ONot :: SOp (TScal TBool) (TScal TBool)
+ OIf :: SOp (TScal TBool) (TEither TNil TNil)
deriving instance Show (SOp a t)
opt2 :: SOp a t -> STy t
@@ -158,6 +159,7 @@ opt2 = \case
OLe _ -> STScal STBool
OEq _ -> STScal STBool
ONot -> STScal STBool
+ OIf -> STEither STNil STNil
typeOf :: Expr x env t -> STy t
typeOf = \case
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
new file mode 100644
index 0000000..baf132e
--- /dev/null
+++ b/src/AST/Count.hs
@@ -0,0 +1,55 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE GADTs #-}
+module AST.Count where
+
+import AST
+
+
+data Count = Zero | One | Many
+ deriving (Show, Eq, Ord)
+
+instance Semigroup Count where
+ Zero <> n = n
+ n <> Zero = n
+ _ <> _ = Many
+instance Monoid Count where
+ mempty = Zero
+
+data Occ = Occ { _occLexical :: Count
+ , _occRuntime :: Count }
+ deriving (Eq)
+instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d)
+instance Monoid Occ where mempty = Occ mempty mempty
+
+-- | One of the two branches is taken
+(<||>) :: Occ -> Occ -> Occ
+Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
+
+-- | This code is executed many times
+scaleMany :: Occ -> Occ
+scaleMany (Occ l _) = Occ l Many
+
+occCount :: Idx env a -> Expr x env t -> Occ
+occCount idx = \case
+ EVar _ _ i | idx2int i == idx2int idx -> Occ One One
+ | otherwise -> mempty
+ ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body
+ EPair _ a b -> occCount idx a <> occCount idx b
+ EFst _ e -> occCount idx e
+ ESnd _ e -> occCount idx e
+ ENil _ -> mempty
+ EInl _ _ e -> occCount idx e
+ EInr _ _ e -> occCount idx e
+ ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b)
+ EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b)
+ EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e)
+ EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b
+ EConst{} -> mempty
+ EIdx1 _ a b -> occCount idx a <> occCount idx b
+ EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es
+ EOp _ _ e -> occCount idx e
+ EMOne _ _ e -> occCount idx e
+ EMScope e -> occCount idx e
+ EMReturn _ e -> occCount idx e
+ EMBind a b -> occCount idx a <> occCount (IS idx) b
+ EError{} -> mempty
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index e793ce1..289294d 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -14,6 +14,7 @@ import Data.Foldable (toList)
import Data.Functor.Const
import AST
+import AST.Count
data Val f env where
@@ -42,6 +43,12 @@ genId = M (\i -> (i, i + 1))
genName :: M String
genName = ('x' :) . show <$> genId
+genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
+genNameIfUsedIn ty idx ex
+ | occCount idx ex == mempty = case ty of STNil -> return "()"
+ _ -> return "_"
+ | otherwise = genName
+
ppExpr :: SList STy env -> Expr x env t -> String
ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
where
@@ -59,7 +66,7 @@ ppExpr' d val = \case
etop@ELet{} -> do
let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
collect val' (ELet _ rhs body) = do
- name <- genName
+ name <- genNameIfUsedIn (typeOf rhs) IZ body
(binds, core) <- collect (VPush (Const name) val') body
rhs' <- ppExpr' 0 val' rhs
return ((name, rhs') : binds, core)
@@ -101,9 +108,10 @@ ppExpr' d val = \case
ECase _ e a b -> do
e' <- ppExpr' 0 val e
- name1 <- genName
+ let STEither t1 t2 = typeOf e
+ name1 <- genNameIfUsedIn t1 IZ a
a' <- ppExpr' 0 (VPush (Const name1) val) a
- name2 <- genName
+ name2 <- genNameIfUsedIn t2 IZ b
b' <- ppExpr' 0 (VPush (Const name2) val) b
return $ showParen (d > 0) $
showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
@@ -111,14 +119,14 @@ ppExpr' d val = \case
EBuild1 _ a b -> do
a' <- ppExpr' 11 val a
- name <- genName
+ name <- genNameIfUsedIn (STScal STI64) IZ b
b' <- ppExpr' 0 (VPush (Const name) val) b
return $ showParen (d > 10) $
showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")"
EBuild _ es e -> do
es' <- mapM (ppExpr' 0 val) es
- names <- mapM (const genName) es
+ names <- mapM (const genName) es -- TODO generate underscores
e' <- ppExpr' 0 (vpushN names val) e
return $ showParen (d > 10) $
showString "build ["
@@ -128,8 +136,8 @@ ppExpr' d val = \case
. showString ("] -> ") . e' . showString ")"
EFold1 _ a b -> do
- name1 <- genName
- name2 <- genName
+ name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
+ name2 <- genNameIfUsedIn (typeOf a) IZ a
a' <- ppExpr' 0 (VPush (Const name2) (VPush (Const name1) val)) a
b' <- ppExpr' 11 val b
return $ showParen (d > 10) $
@@ -185,10 +193,11 @@ ppExpr' d val = \case
e' <- ppExpr' 11 val e
return $ showParen (d > 10) $ showString ("return ") . e'
- etop@(EMBind _ EMBind{}) -> do
+ etop@(EMBind _ _) -> do
let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
collect val' (EMBind lhs cont) = do
- name <- genName
+ let STEVM _ t = typeOf lhs
+ name <- genNameIfUsedIn t IZ cont
(binds, core) <- collect (VPush (Const name) val') cont
lhs' <- ppExpr' 0 val' lhs
return ((name, lhs') : binds, core)
@@ -201,11 +210,11 @@ ppExpr' d val = \case
(map (\(name, rhs) -> showString (name ++ " <- ") . rhs) binds))
. showString " ; " . core . showString " }"
- EMBind a b -> do
- a' <- ppExpr' 0 val a
- name <- genName
- b' <- ppExpr' 0 (VPush (Const name) val) b
- return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b'
+ -- EMBind a b -> do
+ -- a' <- ppExpr' 0 val a
+ -- name <- genNameIfUsedIn IZ b
+ -- b' <- ppExpr' 0 (VPush (Const name) val) b
+ -- return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b'
EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
@@ -220,3 +229,4 @@ operator OLt{} = (Infix, "<")
operator OLe{} = (Infix, "<=")
operator OEq{} = (Infix, "==")
operator ONot = (Prefix, "not")
+operator OIf = (Prefix, "ifB")
diff --git a/src/CHAD.hs b/src/CHAD.hs
index d0358b8..9a1c7d2 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -204,39 +204,44 @@ d1op (OLt t) e = EOp ext (OLt t) e
d1op (OLe t) e = EOp ext (OLe t) e
d1op (OEq t) e = EOp ext (OEq t) e
d1op ONot e = EOp ext ONot e
+d1op OIf e = EOp ext OIf e
+
+data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
+ | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))
-- both primal and dual must be duplicable expressions
-d2op :: SOp a t -> Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)
-d2op op e d = case op of
- OAdd _ -> EInr ext STNil (EPair ext d d)
- OMul t -> d2opBinArrangeInt t $
+d2op :: SOp a t -> D2Op a t
+d2op op = case op of
+ OAdd _ -> Linear $ \d -> EInr ext STNil (EPair ext d d)
+ OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
(EOp ext (OMul t) (EPair ext (EFst ext e) d)))
- ONeg t -> d2opUnArrangeInt t $ EOp ext (ONeg t) d
- OLt t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OLe t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OEq t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- ONot -> ENil ext
+ ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
+ OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ ONot -> Linear $ \_ -> ENil ext
+ OIf -> Linear $ \_ -> ENil ext
where
d2opUnArrangeInt :: SScalTy a
- -> (D2s a ~ TScal a => Ex env (D2 (TScal a)))
- -> Ex env (D2 (TScal a))
+ -> (D2s a ~ TScal a => D2Op (TScal a) t)
+ -> D2Op (TScal a) t
d2opUnArrangeInt ty float = case ty of
- STI32 -> ENil ext
- STI64 -> ENil ext
+ STI32 -> Linear $ \_ -> ENil ext
+ STI64 -> Linear $ \_ -> ENil ext
STF32 -> float
STF64 -> float
- STBool -> ENil ext
+ STBool -> Linear $ \_ -> ENil ext
d2opBinArrangeInt :: SScalTy a
- -> (D2s a ~ TScal a => Ex env (D2 (TPair (TScal a) (TScal a))))
- -> Ex env (D2 (TPair (TScal a) (TScal a)))
+ -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
+ -> D2Op (TPair (TScal a) (TScal a)) t
d2opBinArrangeInt ty float = case ty of
- STI32 -> EInl ext (STPair STNil STNil) (ENil ext)
- STI64 -> EInl ext (STPair STNil STNil) (ENil ext)
+ STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
STF32 -> float
STF64 -> float
- STBool -> EInl ext (STPair STNil STNil) (ENil ext)
+ STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
freezeRet :: Ret env t
-> (forall env'. Ex env' (D2 t)) -- the incoming cotangent value
@@ -359,10 +364,17 @@ drev senv = \case
EOp _ op e
| Ret e0 e1 e2 <- drev senv e ->
- Ret (e0 `BPush` (d1 (typeOf e), e1))
- (d1op op $ EVar ext (d1 (typeOf e)) IZ)
- (ELet ext (d2op op (EVar ext (d1 (typeOf e)) (IS IZ))
- (EVar ext (d2 (opt2 op)) IZ))
- (weakenExpr (WCopy (wSinks @[_,_])) e2))
+ case d2op op of
+ Linear d2opfun ->
+ Ret e0
+ (d1op op e1)
+ (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ))
+ (weakenExpr (WCopy WSink) e2))
+ Nonlinear d2opfun ->
+ Ret (e0 `BPush` (d1 (typeOf e), e1))
+ (d1op op $ EVar ext (d1 (typeOf e)) IZ)
+ (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ))
+ (EVar ext (d2 (opt2 op)) IZ))
+ (weakenExpr (WCopy (wSinks @[_,_])) e2))
e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e)
diff --git a/src/Example.hs b/src/Example.hs
index c8f12ba..f2e5966 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -58,3 +58,39 @@ ex2 =
(bin (OAdd STF32)
(EVar ext (STScal STF32) IZ)
(EVar ext (STScal STF32) (IS (IS IZ))))
+
+-- x y |- if x < y then 2 * x else 3 + x
+ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex3 =
+ ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ))
+ (EVar ext (STScal STF32) IZ)))
+ (bin (OMul STF32) (EConst ext STF32 2.0)
+ (EVar ext (STScal STF32) (IS (IS IZ))))
+ (bin (OAdd STF32) (EConst ext STF32 3.0)
+ (EVar ext (STScal STF32) (IS (IS IZ))))
+
+-- x y |- if x < y then 2 * x + y * y else 3 + x
+ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex4 =
+ ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ))
+ (EVar ext (STScal STF32) IZ)))
+ (bin (OAdd STF32)
+ (bin (OMul STF32) (EConst ext STF32 2.0)
+ (EVar ext (STScal STF32) (IS (IS IZ))))
+ (bin (OMul STF32) (EVar ext (STScal STF32) (IS IZ))
+ (EVar ext (STScal STF32) (IS IZ))))
+ (bin (OAdd STF32) (EConst ext STF32 3.0)
+ (EVar ext (STScal STF32) (IS (IS IZ))))
+
+senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
+senv5 = STScal STF32 `SCons` STEither (STScal STF32) (STScal STF32) `SCons` SNil
+
+-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
+ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
+ex5 =
+ ECase ext (EVar ext (STEither (STScal STF32) (STScal STF32)) (IS IZ))
+ (bin (OMul STF32) (EVar ext (STScal STF32) IZ)
+ (EVar ext (STScal STF32) (IS IZ)))
+ (bin (OMul STF32) (EVar ext (STScal STF32) IZ)
+ (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ))
+ (EConst ext STF32 1.0)))
diff --git a/src/Simplify.hs b/src/Simplify.hs
index cb649d5..acc2392 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -6,8 +6,13 @@
module Simplify where
import AST
+import AST.Count
+simplifyN :: Int -> Ex env t -> Ex env t
+simplifyN 0 = id
+simplifyN n = simplifyN (n - 1) . simplify
+
simplify :: Ex env t -> Ex env t
simplify = \case
-- inlining
@@ -28,9 +33,11 @@ simplify = \case
IS i -> EVar ext t (IS (IS i)))
body
+ -- beta rules for products
EFst _ (EPair _ e _) -> simplify e
ESnd _ (EPair _ _ e) -> simplify e
+ -- beta rules for coproducts
ECase _ (EInl _ _ e) rhs _ -> simplify (ELet ext e rhs)
ECase _ (EInr _ _ e) _ rhs -> simplify (ELet ext e rhs)
@@ -38,6 +45,9 @@ simplify = \case
-- TODO: constant folding for operations
+ -- eta rule for return+bind
+ EMBind (EMReturn _ a) b -> simplify (ELet ext a b)
+
EVar _ t i -> EVar ext t i
ELet _ a b -> ELet ext (simplify a) (simplify b)
EPair _ a b -> EPair ext (simplify a) (simplify b)
@@ -67,54 +77,6 @@ cheapExpr = \case
EConst{} -> True
_ -> False
-data Count = Zero | One | Many
- deriving (Show, Eq, Ord)
-
-instance Semigroup Count where
- Zero <> n = n
- n <> Zero = n
- _ <> _ = Many
-instance Monoid Count where
- mempty = Zero
-
-data Occ = Occ { _occLexical :: Count
- , _occRuntime :: Count }
-instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d)
-instance Monoid Occ where mempty = Occ mempty mempty
-
--- | One of the two branches is taken
-(<||>) :: Occ -> Occ -> Occ
-Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
-
--- | This code is executed many times
-scaleMany :: Occ -> Occ
-scaleMany (Occ l _) = Occ l Many
-
-occCount :: Idx env a -> Expr x env t -> Occ
-occCount idx = \case
- EVar _ _ i | idx2int i == idx2int idx -> Occ One One
- | otherwise -> mempty
- ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body
- EPair _ a b -> occCount idx a <> occCount idx b
- EFst _ e -> occCount idx e
- ESnd _ e -> occCount idx e
- ENil _ -> mempty
- EInl _ _ e -> occCount idx e
- EInr _ _ e -> occCount idx e
- ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b)
- EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b)
- EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e)
- EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b
- EConst{} -> mempty
- EIdx1 _ a b -> occCount idx a <> occCount idx b
- EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es
- EOp _ _ e -> occCount idx e
- EMOne _ _ e -> occCount idx e
- EMScope e -> occCount idx e
- EMReturn _ e -> occCount idx e
- EMBind a b -> occCount idx a <> occCount (IS idx) b
- EError{} -> mempty
-
subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
subst1 repl = subst $ \x t -> \case IZ -> repl
IS i -> EVar x t i