diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-05-25 23:35:31 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-05-25 23:35:31 +0200 |
commit | c36849cb6247f957b4e6b093e16d04421c8cea3d (patch) | |
tree | fdcdcec5c598c95c493ede2782a96563a32b4b5f | |
parent | b0b562e5000dbcac8b944801e7ab96556855a4ff (diff) |
ERecompute
-rw-r--r-- | src/AST.hs | 7 | ||||
-rw-r--r-- | src/AST/Bindings.hs | 11 | ||||
-rw-r--r-- | src/AST/Count.hs | 1 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 4 | ||||
-rw-r--r-- | src/AST/SplitLets.hs | 1 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 1 | ||||
-rw-r--r-- | src/Analysis/Identity.hs | 4 | ||||
-rw-r--r-- | src/CHAD.hs | 41 | ||||
-rw-r--r-- | src/Compile.hs | 2 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 1 | ||||
-rw-r--r-- | src/Interpreter.hs | 1 | ||||
-rw-r--r-- | src/Language.hs | 3 | ||||
-rw-r--r-- | src/Language/AST.hs | 4 | ||||
-rw-r--r-- | src/Simplify.hs | 2 |
14 files changed, 74 insertions, 9 deletions
@@ -87,6 +87,9 @@ data Expr x env t where -> Expr x env a -> Expr x env b -> Expr x env t + -- fake halfway checkpointing + ERecompute :: x t -> Expr x env t -> Expr x env t + -- accumulation effect on monoids EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil @@ -212,6 +215,7 @@ typeOf = \case EOp _ op _ -> opt2 op ECustom _ _ _ _ e _ _ _ _ -> typeOf e + ERecompute _ e -> typeOf e EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) EAccum _ _ _ _ _ _ -> STNil @@ -255,6 +259,7 @@ extOf = \case EShape x _ -> x EOp x _ _ -> x ECustom x _ _ _ _ _ _ _ _ -> x + ERecompute x _ -> x EWith x _ _ _ -> x EAccum x _ _ _ _ _ -> x EZero x _ _ -> x @@ -299,6 +304,7 @@ travExt f = \case EShape x e -> EShape <$> f x <*> travExt f e EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 + ERecompute x e -> ERecompute <$> f x <*> travExt f e EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 EAccum x t p e1 e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> travExt f e2 <*> travExt f e3 EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e @@ -356,6 +362,7 @@ subst' f w = \case EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) + ERecompute x e -> ERecompute x (subst' f w e) EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) EZero x t e -> EZero x t (subst' f w e) diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs index 3d99afe..745a93b 100644 --- a/src/AST/Bindings.hs +++ b/src/AST/Bindings.hs @@ -16,6 +16,7 @@ module AST.Bindings where import AST +import AST.Env import Data import Lemmas @@ -62,3 +63,13 @@ bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds) letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t letBinds BTop = id letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs + +collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env' +collectBindings = \env -> fst . go env WId + where + go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0) + go _ _ SETop = (BTop, WId) + go (ty `SCons` env) w (SEYes sub) = + let (bs, w') = go env (WPop w) sub + in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w') + go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/AST/Count.hs b/src/AST/Count.hs index feaaa1e..0c682c6 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -132,6 +132,7 @@ occCountGeneral onehot unpush alter many = go WId EShape _ e -> re e EOp _ _ e -> re e ECustom _ _ _ _ _ _ _ a b -> re a <> re b + ERecompute _ e -> re e EWith _ _ a b -> re a <> re1 b EAccum _ _ _ a b e -> re a <> re b <> re e EZero _ _ e -> re e diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 6d70ca3..41da656 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -288,6 +288,10 @@ ppExpr' d val expr = case expr of ,e1' ,e2'] + ERecompute _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e'] + EWith _ t e1 e2 -> do e1' <- ppExpr' 11 val e1 name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2 diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 1379e35..3c353d4 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -61,6 +61,7 @@ splitLets' = \sub -> \case EShape x e -> EShape x (splitLets' sub e) EOp x op e -> EOp x op (splitLets' sub e) ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) + ERecompute x e -> ERecompute x (splitLets' sub e) EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) EZero x t ezi -> EZero x t (splitLets' sub ezi) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 3d5f544..ac4d733 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -47,6 +47,7 @@ unMonoid = \case EShape _ e -> EShape ext (unMonoid e) EOp _ op e -> EOp ext op (unMonoid e) ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) + ERecompute _ e -> ERecompute ext (unMonoid e) EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) EError _ t s -> EError ext t s diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index a1a6376..4501c32 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -294,6 +294,10 @@ idana env expr = case expr of res <- genIds t4 pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5') + ERecompute _ e -> do + (v, e') <- idana env e + pure (v, ERecompute v e') + EWith _ t e1 e2 -> do let t1 = typeOf e1 (_, e1') <- idana env e1 diff --git a/src/CHAD.hs b/src/CHAD.hs index 3a7b907..df792ce 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -62,14 +62,14 @@ tapeTy :: SList STy binds -> STy (Tape binds) tapeTy SNil = STNil tapeTy (SCons t ts) = STPair t (tapeTy ts) -bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds +bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) -bindingsCollect BTop SETop _ = ENil ext -bindingsCollect (BPush binds (t, _)) (SEYes sub) w = +bindingsCollectTape BTop SETop _ = ENil ext +bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w = EPair ext (EVar ext t (w @> IZ)) - (bindingsCollect binds sub (w .> WSink)) -bindingsCollect (BPush binds _) (SENo sub) w = - bindingsCollect binds sub (w .> WSink) + (bindingsCollectTape binds sub (w .> WSink)) +bindingsCollectTape (BPush binds _) (SENo sub) w = + bindingsCollectTape binds sub (w .> WSink) -- In order from large to small: i.e. in reverse order from what we want, -- because in a Bindings, the head of the list is the bottom-most entry. @@ -718,8 +718,8 @@ drev des accumMap = \case , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB) - , let collectA = bindingsCollect a0 subtapeA - , let collectB = bindingsCollect b0 subtapeB + , let collectA = bindingsCollectTape a0 subtapeA + , let collectB = bindingsCollectTape b0 subtapeB , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 @@ -822,6 +822,29 @@ drev des accumMap = \case (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ weakenExpr (WCopy (WSink .> WSink)) b2) + -- TODO: compute primal in direct form here instead of taking the redundantly inefficient CHAD primal + ERecompute _ e -> + deleteUnused (descrList des) (occCountAll e) $ \usedSub -> + let smallE = unsafeWeakenWithSubenv usedSub e in + subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> + case drev usedDes (VarMap.subMap subAccumUsed accumMap) smallE of { Ret e0 subtape e1 sub e2 -> + Ret (collectBindings (desD1E des) subD1eUsed) + (subenvAll (desD1E usedDes)) + (weakenExpr (wRaiseAbove (desD1E usedDes) (desD1E des)) $ letBinds e0 e1) + (subenvCompose subMergeUsed sub) + (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ + weakenExpr + (autoWeak (#d (auto1 @(D2 t)) + &. #shbinds (bindingsBinds e0) + &. #tape (subList (bindingsBinds e0) subtape) + &. #d1env (desD1E usedDes) + &. #tl' (d2ace (select SAccum usedDes)) + &. #tl (d2ace (select SAccum des))) + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) + (#shbinds :++: #d :++: #d1env :++: #tl)) + e2) + } + EError _ t s -> Ret BTop SETop @@ -849,7 +872,7 @@ drev des accumMap = \case case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in - let collectexpr = bindingsCollect e0 subtapeE in + let collectexpr = bindingsCollectTape e0 subtapeE in Ret (BTop `BPush` (shty, letBinds she0 she1) `BPush` (STArr ndim (STPair (d1 eltty) tapety) ,EBuild ext ndim diff --git a/src/Compile.hs b/src/Compile.hs index 9ed5a27..722b432 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -977,6 +977,8 @@ compile' env = \case maybe (return ()) ($ name2) mfun2 return (CELit name) + ERecompute _ e -> compile' env e + EWith _ t e1 e2 -> do actyname <- emitStruct (STAccum t) name1 <- compileAssign "" env e1 diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index ebc70d7..a6d5ec8 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -185,6 +185,7 @@ dfwdDN = \case ELet ext (dfwdDN e1) $ ELet ext (weakenExpr WSink (dfwdDN e2)) $ weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) + ERecompute _ e -> dfwdDN e EError _ t s -> EError ext (dn t) s EWith{} -> err_accum diff --git a/src/Interpreter.hs b/src/Interpreter.hs index d7916d8..803a24a 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -153,6 +153,7 @@ interpret'Rec env = \case e1' <- interpret' env e1 e2' <- interpret' env e2 interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr + ERecompute _ e -> interpret' env e EWith _ t e1 e2 -> do initval <- interpret' env e1 withAccum t (typeOf e2) initval $ \accum -> diff --git a/src/Language.hs b/src/Language.hs index 9fd5dd3..7a780a0 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -169,6 +169,9 @@ custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 +recompute :: NExpr env a -> NExpr env a +recompute = NERecompute + with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) with a (n :-> b) = NEWith (knownMTy @t) a n b diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 8bcb5e5..7e074df 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -71,6 +71,9 @@ data NExpr env t where -> NExpr env a -> NExpr env b -> NExpr env t + -- fake halfway checkpointing + NERecompute :: NExpr env t -> NExpr env t + -- accumulation effect on monoids NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil @@ -215,6 +218,7 @@ fromNamedExpr val = \case (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) (go e1) (go e2) + NERecompute e -> ERecompute ext (go e) NEWith t a n b -> EWith ext t (go a) (lambda val n b) NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c) diff --git a/src/Simplify.hs b/src/Simplify.hs index 6f97e6d..d963b7e 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -291,6 +291,7 @@ simplify'Rec = \case e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1) e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2) pure (ECustom ext s t p a' b' c' e1' e2') + ERecompute _ e -> [simprec| ERecompute ext *e |] EWith _ t e1 e2 -> do e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) @@ -345,6 +346,7 @@ hasAdds = \case EShape _ e -> hasAdds e EOp _ _ e -> hasAdds e EWith _ _ a b -> hasAdds a || hasAdds b + ERecompute _ e -> hasAdds e EAccum _ _ _ _ _ _ -> True EZero _ _ e -> hasAdds e EPlus _ _ a b -> hasAdds a || hasAdds b |