diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 20:29:29 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 20:29:29 +0100 |
commit | 4fcdb7118e0084f192753ea6c70394352a27d5ed (patch) | |
tree | c5e91ae438b6f4c3e5075bf591e5fbe28aa5d96b | |
parent | 83692cf41f76272423445c9cbbad65561ee3b50c (diff) |
Custom derivatives
-rw-r--r-- | src/AST.hs | 15 | ||||
-rw-r--r-- | src/AST/Count.hs | 2 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 27 | ||||
-rw-r--r-- | src/CHAD.hs | 15 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 7 | ||||
-rw-r--r-- | src/Interpreter.hs | 4 | ||||
-rw-r--r-- | src/Simplify.hs | 13 |
7 files changed, 53 insertions, 30 deletions
@@ -26,7 +26,6 @@ import AST.Types import AST.Weaken import CHAD.Types import Data -import ForwardAD.DualNumbers.Types -- | This index is flipped around from the usual direction: the smallest index @@ -98,10 +97,14 @@ data Expr x env t where EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t -- custom derivatives - ECustom :: x t -> STy a -> STy b + -- 'b' is the part of the input of the operation that derivatives should + -- be backpropagated to; 'a' is the inactive part. The dual field of + -- ECustom does not allow a derivative to be generated for 'a', and hence + -- none is propagated. + ECustom :: x t -> STy a -> STy b -> STy tape -> Expr x '[b, a] t -- ^ regular operation - -> Expr x '[DN b, a] (DN t) -- ^ dual-numbers forward derivative - -> Expr x '[D2 t, D1 b, D1 a] (D2 b) -- ^ CHAD reverse derivative + -> Expr x '[D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Expr x '[D2 t, tape] (D2 b) -- ^ CHAD reverse derivative -> Expr x env a -> Expr x env b -> Expr x env t @@ -211,7 +214,7 @@ typeOf = \case EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) EOp _ op _ -> opt2 op - ECustom _ _ _ e _ _ _ _ -> typeOf e + ECustom _ _ _ _ e _ _ _ _ -> typeOf e EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) EAccum _ _ _ _ -> STNil @@ -285,7 +288,7 @@ subst' f w = \case EIdx x e es -> EIdx x (subst' f w e) (subst' f w es) EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) - ECustom x s t a b c e1 e2 -> ECustom x s t a b c (subst' f w e1) (subst' f w e2) + ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3) EZero t -> EZero t diff --git a/src/AST/Count.hs b/src/AST/Count.hs index d365218..364773a 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -125,7 +125,7 @@ occCountGeneral onehot unpush alter many = go WId EIdx _ a b -> re a <> re b EShape _ e -> re e EOp _ _ e -> re e - ECustom _ _ _ _ _ _ a b -> re a <> re b + ECustom _ _ _ _ _ _ _ a b -> re a <> re b EWith a b -> re a <> re1 b EAccum _ a b e -> re a <> re b <> re e EZero _ -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index bac267d..a2232ee 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -16,7 +16,6 @@ import AST import AST.Count import CHAD.Types import Data -import ForwardAD.DualNumbers.Types type SVal = SList (Const String) @@ -177,20 +176,24 @@ ppExpr' d val = \case (Prefix, s) -> s return $ showParen (d > 10) $ showString (ops ++ " ") . e' - ECustom _ t1 t2 a b c e1 e2 -> do - pn1 <- genNameIfUsedIn t1 (IS IZ) a - pn2 <- genNameIfUsedIn t2 IZ a - fn1 <- genNameIfUsedIn t1 (IS IZ) b - fn2 <- genNameIfUsedIn (dn t2) IZ b - rn1 <- genNameIfUsedIn (d1 t1) (IS (IS IZ)) c - rn2 <- genNameIfUsedIn (d1 t2) (IS IZ) c - rn3 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c + ECustom _ t1 t2 t3 a b c e1 e2 -> do + en1 <- genNameIfUsedIn t1 (IS IZ) a + en2 <- genNameIfUsedIn t2 IZ a + pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b + pn2 <- genNameIfUsedIn (d1 t2) IZ b + dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c + dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a - b' <- ppExpr' 11 (Const fn2 `SCons` Const fn1 `SCons` SNil) b - c' <- ppExpr' 11 (Const rn3 `SCons` Const rn2 `SCons` Const rn1 `SCons` SNil) c + b' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) b + c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 - return $ showParen (d > 10) $ showString "custom " . a' . showString " " . b' . showString " " . c' . showString " " . e1' . showString " " . e2' + return $ showParen (d > 10) $ showString "custom " + . showString ("(" ++ en1 ++ " " ++ en2 ++ ". ") . a' . showString ") " + . showString ("(" ++ pn1 ++ " " ++ pn2 ++ ". ") . b' . showString ") " + . showString ("(" ++ dn1 ++ " " ++ dn2 ++ ". ") . c' . showString ") " + . e1' . showString " " + . e2' EWith e1 e2 -> do e1' <- ppExpr' 11 val e1 diff --git a/src/CHAD.hs b/src/CHAD.hs index 8080ec0..2f05807 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -874,6 +874,20 @@ drev des = \case (EVar ext (d2 (opt2 op)) IZ)) (weakenExpr (WCopy (wSinks' @[_,_])) e2)) + ECustom _ _ _ storety _ pr du a b + -- allowed to ignore a2 because 'a' is the part of the input that is inactive + | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) + <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil -> + Ret (binds `BPush` (typeOf a1, a1) + `BPush` (typeOf b1, weakenExpr WSink b1) + `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) pr) + `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) + (SEYes (SENo (SENo (SENo subtape)))) + (EFst ext (EVar ext (typeOf pr) (IS IZ))) + bsub + (ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $ + weakenExpr (WCopy (WSink .> WSink)) b2) + EError t s -> Ret BTop SETop @@ -888,7 +902,6 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) - -- TODO: merge the e0 and e1 builds in a single build just like they are merged into a single case in D[case]0, then it can really store only the parts that need to be preserved until D[build]2 EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty) | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result , let eltty = typeOf orige diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 3587378..f2ded6e 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -144,11 +144,10 @@ dfwdDN = \case | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e) EOp _ op e -> dop op (dfwdDN e) - ECustom _ s t _ du _ e1 e2 -> - -- TODO: we need a bit of codegen here that projects the primals out from the dual number result of e1. Note that a non-differentiating code transformation does not eliminate the need for this, because then the need just shifts to free variable adaptor code. - ELet ext (_ e1) $ + ECustom _ _ _ _ pr _ _ e1 e2 -> + ELet ext (dfwdDN e1) $ ELet ext (weakenExpr WSink (dfwdDN e2)) $ - weakenExpr (WCopy (WCopy WClosed)) du + weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) EError t s -> EError (dn t) s EWith{} -> err_accum diff --git a/src/Interpreter.hs b/src/Interpreter.hs index abc9800..47514ae 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -118,6 +118,10 @@ interpret'Rec env = \case -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e EOp _ op e -> interpretOp op <$> interpret' env e + ECustom _ _ _ _ pr _ _ e1 e2 -> do + e1' <- interpret' env e1 + e2' <- interpret' env e2 + interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr EWith e1 e2 -> do initval <- interpret' env e1 withAccum (typeOf e1) (typeOf e2) initval $ \accum -> diff --git a/src/Simplify.hs b/src/Simplify.hs index d3ee03c..e32ba8c 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -121,11 +121,12 @@ simplify' = \case EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b EShape _ e -> EShape ext <$> simplify' e EOp _ op e -> EOp ext op <$> simplify' e - ECustom _ s t a b c e1 e2 -> - ECustom ext s t <$> (let ?accumInScope = False in simplify' a) - <*> (let ?accumInScope = False in simplify' b) - <*> (let ?accumInScope = False in simplify' c) - <*> simplify' e1 <*> simplify' e2 + ECustom _ s t p a b c e1 e2 -> + ECustom ext s t p + <$> (let ?accumInScope = False in simplify' a) + <*> (let ?accumInScope = False in simplify' b) + <*> (let ?accumInScope = False in simplify' c) + <*> simplify' e1 <*> simplify' e2 EWith e1 e2 -> EWith <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3 EZero t -> pure $ EZero t @@ -165,7 +166,7 @@ hasAdds = \case ESum1Inner _ e -> hasAdds e EUnit _ e -> hasAdds e EReplicate1Inner _ a b -> hasAdds a || hasAdds b - ECustom _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e + ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e EIdx1 _ a b -> hasAdds a || hasAdds b |