diff options
| -rw-r--r-- | src/AST.hs | 7 | ||||
| -rw-r--r-- | src/AST/Bindings.hs | 11 | ||||
| -rw-r--r-- | src/AST/Count.hs | 1 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 4 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 1 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 1 | ||||
| -rw-r--r-- | src/Analysis/Identity.hs | 4 | ||||
| -rw-r--r-- | src/CHAD.hs | 41 | ||||
| -rw-r--r-- | src/Compile.hs | 2 | ||||
| -rw-r--r-- | src/ForwardAD/DualNumbers.hs | 1 | ||||
| -rw-r--r-- | src/Interpreter.hs | 1 | ||||
| -rw-r--r-- | src/Language.hs | 3 | ||||
| -rw-r--r-- | src/Language/AST.hs | 4 | ||||
| -rw-r--r-- | src/Simplify.hs | 2 | 
14 files changed, 74 insertions, 9 deletions
| @@ -87,6 +87,9 @@ data Expr x env t where            -> Expr x env a -> Expr x env b            -> Expr x env t +  -- fake halfway checkpointing +  ERecompute :: x t -> Expr x env t -> Expr x env t +    -- accumulation effect on monoids    EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)    EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil @@ -212,6 +215,7 @@ typeOf = \case    EOp _ op _ -> opt2 op    ECustom _ _ _ _ e _ _ _ _ -> typeOf e +  ERecompute _ e -> typeOf e    EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)    EAccum _ _ _ _ _ _ -> STNil @@ -255,6 +259,7 @@ extOf = \case    EShape x _ -> x    EOp x _ _ -> x    ECustom x _ _ _ _ _ _ _ _ -> x +  ERecompute x _ -> x    EWith x _ _ _ -> x    EAccum x _ _ _ _ _ -> x    EZero x _ _ -> x @@ -299,6 +304,7 @@ travExt f = \case    EShape x e -> EShape <$> f x <*> travExt f e    EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e    ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 +  ERecompute x e -> ERecompute <$> f x <*> travExt f e    EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2    EAccum x t p e1 e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> travExt f e2 <*> travExt f e3    EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e @@ -356,6 +362,7 @@ subst' f w = \case    EShape x e -> EShape x (subst' f w e)    EOp x op e -> EOp x op (subst' f w e)    ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) +  ERecompute x e -> ERecompute x (subst' f w e)    EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)    EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3)    EZero x t e -> EZero x t (subst' f w e) diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs index 3d99afe..745a93b 100644 --- a/src/AST/Bindings.hs +++ b/src/AST/Bindings.hs @@ -16,6 +16,7 @@  module AST.Bindings where  import AST +import AST.Env  import Data  import Lemmas @@ -62,3 +63,13 @@ bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds)  letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t  letBinds BTop = id  letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs + +collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env' +collectBindings = \env -> fst . go env WId +  where +    go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0) +    go _ _ SETop = (BTop, WId) +    go (ty `SCons` env) w (SEYes sub) = +      let (bs, w') = go env (WPop w) sub +      in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w') +    go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/AST/Count.hs b/src/AST/Count.hs index feaaa1e..0c682c6 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -132,6 +132,7 @@ occCountGeneral onehot unpush alter many = go WId        EShape _ e -> re e        EOp _ _ e -> re e        ECustom _ _ _ _ _ _ _ a b -> re a <> re b +      ERecompute _ e -> re e        EWith _ _ a b -> re a <> re1 b        EAccum _ _ _ a b e -> re a <> re b <> re e        EZero _ _ e -> re e diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 6d70ca3..41da656 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -288,6 +288,10 @@ ppExpr' d val expr = case expr of          ,e1'          ,e2'] +  ERecompute _ e -> do +    e' <- ppExpr' 11 val e +    return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e'] +    EWith _ t e1 e2 -> do      e1' <- ppExpr' 11 val e1      name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2 diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 1379e35..3c353d4 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -61,6 +61,7 @@ splitLets' = \sub -> \case    EShape x e -> EShape x (splitLets' sub e)    EOp x op e -> EOp x op (splitLets' sub e)    ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) +  ERecompute x e -> ERecompute x (splitLets' sub e)    EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)    EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3)    EZero x t ezi -> EZero x t (splitLets' sub ezi) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 3d5f544..ac4d733 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -47,6 +47,7 @@ unMonoid = \case    EShape _ e -> EShape ext (unMonoid e)    EOp _ op e -> EOp ext op (unMonoid e)    ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) +  ERecompute _ e -> ERecompute ext (unMonoid e)    EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)    EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)    EError _ t s -> EError ext t s diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index a1a6376..4501c32 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -294,6 +294,10 @@ idana env expr = case expr of      res <- genIds t4      pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5') +  ERecompute _ e -> do +    (v, e') <- idana env e +    pure (v, ERecompute v e') +    EWith _ t e1 e2 -> do      let t1 = typeOf e1      (_, e1') <- idana env e1 diff --git a/src/CHAD.hs b/src/CHAD.hs index 3a7b907..df792ce 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -62,14 +62,14 @@ tapeTy :: SList STy binds -> STy (Tape binds)  tapeTy SNil = STNil  tapeTy (SCons t ts) = STPair t (tapeTy ts) -bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds +bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds                  -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) -bindingsCollect BTop SETop _ = ENil ext -bindingsCollect (BPush binds (t, _)) (SEYes sub) w = +bindingsCollectTape BTop SETop _ = ENil ext +bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w =    EPair ext (EVar ext t (w @> IZ)) -            (bindingsCollect binds sub (w .> WSink)) -bindingsCollect (BPush binds _) (SENo sub) w = -  bindingsCollect binds sub (w .> WSink) +            (bindingsCollectTape binds sub (w .> WSink)) +bindingsCollectTape (BPush binds _) (SENo sub) w = +  bindingsCollectTape binds sub (w .> WSink)  -- In order from large to small: i.e. in reverse order from what we want,  -- because in a Bindings, the head of the list is the bottom-most entry. @@ -718,8 +718,8 @@ drev des accumMap = \case      , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))      , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)      , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB) -    , let collectA = bindingsCollect a0 subtapeA -    , let collectB = bindingsCollect b0 subtapeB +    , let collectA = bindingsCollectTape a0 subtapeA +    , let collectB = bindingsCollectTape b0 subtapeB      , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)      , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0      , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 @@ -822,6 +822,29 @@ drev des accumMap = \case          (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $             weakenExpr (WCopy (WSink .> WSink)) b2) +  -- TODO: compute primal in direct form here instead of taking the redundantly inefficient CHAD primal +  ERecompute _ e -> +    deleteUnused (descrList des) (occCountAll e) $ \usedSub -> +    let smallE = unsafeWeakenWithSubenv usedSub e in +    subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> +    case drev usedDes (VarMap.subMap subAccumUsed accumMap) smallE of { Ret e0 subtape e1 sub e2 -> +    Ret (collectBindings (desD1E des) subD1eUsed) +        (subenvAll (desD1E usedDes)) +        (weakenExpr (wRaiseAbove (desD1E usedDes) (desD1E des)) $ letBinds e0 e1) +        (subenvCompose subMergeUsed sub) +        (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ +           weakenExpr +             (autoWeak (#d (auto1 @(D2 t)) +                        &. #shbinds (bindingsBinds e0) +                        &. #tape (subList (bindingsBinds e0) subtape) +                        &. #d1env (desD1E usedDes) +                        &. #tl' (d2ace (select SAccum usedDes)) +                        &. #tl (d2ace (select SAccum des))) +                       (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) +                       (#shbinds :++: #d :++: #d1env :++: #tl)) +             e2) +    } +    EError _ t s ->      Ret BTop          SETop @@ -849,7 +872,7 @@ drev des accumMap = \case      case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->      case assertSubenvEmpty sub of { Refl ->      let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in -    let collectexpr = bindingsCollect e0 subtapeE in +    let collectexpr = bindingsCollectTape e0 subtapeE in      Ret (BTop `BPush` (shty, letBinds she0 she1)                `BPush` (STArr ndim (STPair (d1 eltty) tapety)                        ,EBuild ext ndim diff --git a/src/Compile.hs b/src/Compile.hs index 9ed5a27..722b432 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -977,6 +977,8 @@ compile' env = \case          maybe (return ()) ($ name2) mfun2          return (CELit name) +  ERecompute _ e -> compile' env e +    EWith _ t e1 e2 -> do      actyname <- emitStruct (STAccum t)      name1 <- compileAssign "" env e1 diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index ebc70d7..a6d5ec8 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -185,6 +185,7 @@ dfwdDN = \case      ELet ext (dfwdDN e1) $      ELet ext (weakenExpr WSink (dfwdDN e2)) $        weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) +  ERecompute _ e -> dfwdDN e    EError _ t s -> EError ext (dn t) s    EWith{} -> err_accum diff --git a/src/Interpreter.hs b/src/Interpreter.hs index d7916d8..803a24a 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -153,6 +153,7 @@ interpret'Rec env = \case      e1' <- interpret' env e1      e2' <- interpret' env e2      interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr +  ERecompute _ e -> interpret' env e    EWith _ t e1 e2 -> do      initval <- interpret' env e1      withAccum t (typeOf e2) initval $ \accum -> diff --git a/src/Language.hs b/src/Language.hs index 9fd5dd3..7a780a0 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -169,6 +169,9 @@ custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)  custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =    NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 +recompute :: NExpr env a -> NExpr env a +recompute = NERecompute +  with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t)  with a (n :-> b) = NEWith (knownMTy @t) a n b diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 8bcb5e5..7e074df 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -71,6 +71,9 @@ data NExpr env t where             -> NExpr env a -> NExpr env b             -> NExpr env t +  -- fake halfway checkpointing +  NERecompute :: NExpr env t -> NExpr env t +    -- accumulation effect on monoids    NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)    NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil @@ -215,6 +218,7 @@ fromNamedExpr val = \case              (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)              (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)              (go e1) (go e2) +  NERecompute e -> ERecompute ext (go e)    NEWith t a n b -> EWith ext t (go a) (lambda val n b)    NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c) diff --git a/src/Simplify.hs b/src/Simplify.hs index 6f97e6d..d963b7e 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -291,6 +291,7 @@ simplify'Rec = \case      e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1)      e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2)      pure (ECustom ext s t p a' b' c' e1' e2') +  ERecompute _ e -> [simprec| ERecompute ext *e |]    EWith _ t e1 e2 -> do      e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1)      e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) @@ -345,6 +346,7 @@ hasAdds = \case    EShape _ e -> hasAdds e    EOp _ _ e -> hasAdds e    EWith _ _ a b -> hasAdds a || hasAdds b +  ERecompute _ e -> hasAdds e    EAccum _ _ _ _ _ _ -> True    EZero _ _ e -> hasAdds e    EPlus _ _ a b -> hasAdds a || hasAdds b | 
