summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-05-25 23:35:31 +0200
committerTom Smeding <tom@tomsmeding.com>2025-05-25 23:35:31 +0200
commitc36849cb6247f957b4e6b093e16d04421c8cea3d (patch)
treefdcdcec5c598c95c493ede2782a96563a32b4b5f
parentb0b562e5000dbcac8b944801e7ab96556855a4ff (diff)
ERecompute
-rw-r--r--src/AST.hs7
-rw-r--r--src/AST/Bindings.hs11
-rw-r--r--src/AST/Count.hs1
-rw-r--r--src/AST/Pretty.hs4
-rw-r--r--src/AST/SplitLets.hs1
-rw-r--r--src/AST/UnMonoid.hs1
-rw-r--r--src/Analysis/Identity.hs4
-rw-r--r--src/CHAD.hs41
-rw-r--r--src/Compile.hs2
-rw-r--r--src/ForwardAD/DualNumbers.hs1
-rw-r--r--src/Interpreter.hs1
-rw-r--r--src/Language.hs3
-rw-r--r--src/Language/AST.hs4
-rw-r--r--src/Simplify.hs2
14 files changed, 74 insertions, 9 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 65664fc..149cddd 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -87,6 +87,9 @@ data Expr x env t where
-> Expr x env a -> Expr x env b
-> Expr x env t
+ -- fake halfway checkpointing
+ ERecompute :: x t -> Expr x env t -> Expr x env t
+
-- accumulation effect on monoids
EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil
@@ -212,6 +215,7 @@ typeOf = \case
EOp _ op _ -> opt2 op
ECustom _ _ _ _ e _ _ _ _ -> typeOf e
+ ERecompute _ e -> typeOf e
EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
EAccum _ _ _ _ _ _ -> STNil
@@ -255,6 +259,7 @@ extOf = \case
EShape x _ -> x
EOp x _ _ -> x
ECustom x _ _ _ _ _ _ _ _ -> x
+ ERecompute x _ -> x
EWith x _ _ _ -> x
EAccum x _ _ _ _ _ -> x
EZero x _ _ -> x
@@ -299,6 +304,7 @@ travExt f = \case
EShape x e -> EShape <$> f x <*> travExt f e
EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e
ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2
+ ERecompute x e -> ERecompute <$> f x <*> travExt f e
EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2
EAccum x t p e1 e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> travExt f e2 <*> travExt f e3
EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e
@@ -356,6 +362,7 @@ subst' f w = \case
EShape x e -> EShape x (subst' f w e)
EOp x op e -> EOp x op (subst' f w e)
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
+ ERecompute x e -> ERecompute x (subst' f w e)
EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3)
EZero x t e -> EZero x t (subst' f w e)
diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs
index 3d99afe..745a93b 100644
--- a/src/AST/Bindings.hs
+++ b/src/AST/Bindings.hs
@@ -16,6 +16,7 @@
module AST.Bindings where
import AST
+import AST.Env
import Data
import Lemmas
@@ -62,3 +63,13 @@ bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds)
letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
letBinds BTop = id
letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
+
+collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env'
+collectBindings = \env -> fst . go env WId
+ where
+ go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0)
+ go _ _ SETop = (BTop, WId)
+ go (ty `SCons` env) w (SEYes sub) =
+ let (bs, w') = go env (WPop w) sub
+ in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w')
+ go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index feaaa1e..0c682c6 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -132,6 +132,7 @@ occCountGeneral onehot unpush alter many = go WId
EShape _ e -> re e
EOp _ _ e -> re e
ECustom _ _ _ _ _ _ _ a b -> re a <> re b
+ ERecompute _ e -> re e
EWith _ _ a b -> re a <> re1 b
EAccum _ _ _ a b e -> re a <> re b <> re e
EZero _ _ e -> re e
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 6d70ca3..41da656 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -288,6 +288,10 @@ ppExpr' d val expr = case expr of
,e1'
,e2']
+ ERecompute _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e']
+
EWith _ t e1 e2 -> do
e1' <- ppExpr' 11 val e1
name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 1379e35..3c353d4 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -61,6 +61,7 @@ splitLets' = \sub -> \case
EShape x e -> EShape x (splitLets' sub e)
EOp x op e -> EOp x op (splitLets' sub e)
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
+ ERecompute x e -> ERecompute x (splitLets' sub e)
EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3)
EZero x t ezi -> EZero x t (splitLets' sub ezi)
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 3d5f544..ac4d733 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -47,6 +47,7 @@ unMonoid = \case
EShape _ e -> EShape ext (unMonoid e)
EOp _ op e -> EOp ext op (unMonoid e)
ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
+ ERecompute _ e -> ERecompute ext (unMonoid e)
EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
EError _ t s -> EError ext t s
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index a1a6376..4501c32 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -294,6 +294,10 @@ idana env expr = case expr of
res <- genIds t4
pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5')
+ ERecompute _ e -> do
+ (v, e') <- idana env e
+ pure (v, ERecompute v e')
+
EWith _ t e1 e2 -> do
let t1 = typeOf e1
(_, e1') <- idana env e1
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 3a7b907..df792ce 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -62,14 +62,14 @@ tapeTy :: SList STy binds -> STy (Tape binds)
tapeTy SNil = STNil
tapeTy (SCons t ts) = STPair t (tapeTy ts)
-bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds
+bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds
-> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
-bindingsCollect BTop SETop _ = ENil ext
-bindingsCollect (BPush binds (t, _)) (SEYes sub) w =
+bindingsCollectTape BTop SETop _ = ENil ext
+bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w =
EPair ext (EVar ext t (w @> IZ))
- (bindingsCollect binds sub (w .> WSink))
-bindingsCollect (BPush binds _) (SENo sub) w =
- bindingsCollect binds sub (w .> WSink)
+ (bindingsCollectTape binds sub (w .> WSink))
+bindingsCollectTape (BPush binds _) (SENo sub) w =
+ bindingsCollectTape binds sub (w .> WSink)
-- In order from large to small: i.e. in reverse order from what we want,
-- because in a Bindings, the head of the list is the bottom-most entry.
@@ -718,8 +718,8 @@ drev des accumMap = \case
, Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)
, let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB)
- , let collectA = bindingsCollect a0 subtapeA
- , let collectB = bindingsCollect b0 subtapeB
+ , let collectA = bindingsCollectTape a0 subtapeA
+ , let collectB = bindingsCollectTape b0 subtapeB
, (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
, let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
, let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0
@@ -822,6 +822,29 @@ drev des accumMap = \case
(ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
weakenExpr (WCopy (WSink .> WSink)) b2)
+ -- TODO: compute primal in direct form here instead of taking the redundantly inefficient CHAD primal
+ ERecompute _ e ->
+ deleteUnused (descrList des) (occCountAll e) $ \usedSub ->
+ let smallE = unsafeWeakenWithSubenv usedSub e in
+ subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
+ case drev usedDes (VarMap.subMap subAccumUsed accumMap) smallE of { Ret e0 subtape e1 sub e2 ->
+ Ret (collectBindings (desD1E des) subD1eUsed)
+ (subenvAll (desD1E usedDes))
+ (weakenExpr (wRaiseAbove (desD1E usedDes) (desD1E des)) $ letBinds e0 e1)
+ (subenvCompose subMergeUsed sub)
+ (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
+ weakenExpr
+ (autoWeak (#d (auto1 @(D2 t))
+ &. #shbinds (bindingsBinds e0)
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #d1env (desD1E usedDes)
+ &. #tl' (d2ace (select SAccum usedDes))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed))
+ (#shbinds :++: #d :++: #d1env :++: #tl))
+ e2)
+ }
+
EError _ t s ->
Ret BTop
SETop
@@ -849,7 +872,7 @@ drev des accumMap = \case
case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
case assertSubenvEmpty sub of { Refl ->
let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
- let collectexpr = bindingsCollect e0 subtapeE in
+ let collectexpr = bindingsCollectTape e0 subtapeE in
Ret (BTop `BPush` (shty, letBinds she0 she1)
`BPush` (STArr ndim (STPair (d1 eltty) tapety)
,EBuild ext ndim
diff --git a/src/Compile.hs b/src/Compile.hs
index 9ed5a27..722b432 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -977,6 +977,8 @@ compile' env = \case
maybe (return ()) ($ name2) mfun2
return (CELit name)
+ ERecompute _ e -> compile' env e
+
EWith _ t e1 e2 -> do
actyname <- emitStruct (STAccum t)
name1 <- compileAssign "" env e1
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index ebc70d7..a6d5ec8 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -185,6 +185,7 @@ dfwdDN = \case
ELet ext (dfwdDN e1) $
ELet ext (weakenExpr WSink (dfwdDN e2)) $
weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
+ ERecompute _ e -> dfwdDN e
EError _ t s -> EError ext (dn t) s
EWith{} -> err_accum
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index d7916d8..803a24a 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -153,6 +153,7 @@ interpret'Rec env = \case
e1' <- interpret' env e1
e2' <- interpret' env e2
interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr
+ ERecompute _ e -> interpret' env e
EWith _ t e1 e2 -> do
initval <- interpret' env e1
withAccum t (typeOf e2) initval $ \accum ->
diff --git a/src/Language.hs b/src/Language.hs
index 9fd5dd3..7a780a0 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -169,6 +169,9 @@ custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =
NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2
+recompute :: NExpr env a -> NExpr env a
+recompute = NERecompute
+
with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t)
with a (n :-> b) = NEWith (knownMTy @t) a n b
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 8bcb5e5..7e074df 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -71,6 +71,9 @@ data NExpr env t where
-> NExpr env a -> NExpr env b
-> NExpr env t
+ -- fake halfway checkpointing
+ NERecompute :: NExpr env t -> NExpr env t
+
-- accumulation effect on monoids
NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)
NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
@@ -215,6 +218,7 @@ fromNamedExpr val = \case
(fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
(fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
(go e1) (go e2)
+ NERecompute e -> ERecompute ext (go e)
NEWith t a n b -> EWith ext t (go a) (lambda val n b)
NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 6f97e6d..d963b7e 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -291,6 +291,7 @@ simplify'Rec = \case
e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1)
e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2)
pure (ECustom ext s t p a' b' c' e1' e2')
+ ERecompute _ e -> [simprec| ERecompute ext *e |]
EWith _ t e1 e2 -> do
e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1)
e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2)
@@ -345,6 +346,7 @@ hasAdds = \case
EShape _ e -> hasAdds e
EOp _ _ e -> hasAdds e
EWith _ _ a b -> hasAdds a || hasAdds b
+ ERecompute _ e -> hasAdds e
EAccum _ _ _ _ _ _ -> True
EZero _ _ e -> hasAdds e
EPlus _ _ a b -> hasAdds a || hasAdds b