diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 20:29:29 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 20:29:29 +0100 | 
| commit | 4fcdb7118e0084f192753ea6c70394352a27d5ed (patch) | |
| tree | c5e91ae438b6f4c3e5075bf591e5fbe28aa5d96b /src | |
| parent | 83692cf41f76272423445c9cbbad65561ee3b50c (diff) | |
Custom derivatives
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 15 | ||||
| -rw-r--r-- | src/AST/Count.hs | 2 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 27 | ||||
| -rw-r--r-- | src/CHAD.hs | 15 | ||||
| -rw-r--r-- | src/ForwardAD/DualNumbers.hs | 7 | ||||
| -rw-r--r-- | src/Interpreter.hs | 4 | ||||
| -rw-r--r-- | src/Simplify.hs | 13 | 
7 files changed, 53 insertions, 30 deletions
| @@ -26,7 +26,6 @@ import AST.Types  import AST.Weaken  import CHAD.Types  import Data -import ForwardAD.DualNumbers.Types  -- | This index is flipped around from the usual direction: the smallest index @@ -98,10 +97,14 @@ data Expr x env t where    EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t    -- custom derivatives -  ECustom :: x t -> STy a -> STy b +  -- 'b' is the part of the input of the operation that derivatives should +  -- be backpropagated to; 'a' is the inactive part. The dual field of +  -- ECustom does not allow a derivative to be generated for 'a', and hence +  -- none is propagated. +  ECustom :: x t -> STy a -> STy b -> STy tape            -> Expr x '[b, a] t  -- ^ regular operation -          -> Expr x '[DN b, a] (DN t)  -- ^ dual-numbers forward derivative -          -> Expr x '[D2 t, D1 b, D1 a] (D2 b)  -- ^ CHAD reverse derivative +          -> Expr x '[D1 b, D1 a] (TPair (D1 t) tape)  -- ^ CHAD forward pass +          -> Expr x '[D2 t, tape] (D2 b)  -- ^ CHAD reverse derivative            -> Expr x env a -> Expr x env b            -> Expr x env t @@ -211,7 +214,7 @@ typeOf = \case    EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)    EOp _ op _ -> opt2 op -  ECustom _ _ _ e _ _ _ _ -> typeOf e +  ECustom _ _ _ _ e _ _ _ _ -> typeOf e    EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)    EAccum _ _ _ _ -> STNil @@ -285,7 +288,7 @@ subst' f w = \case    EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)    EShape x e -> EShape x (subst' f w e)    EOp x op e -> EOp x op (subst' f w e) -  ECustom x s t a b c e1 e2 -> ECustom x s t a b c (subst' f w e1) (subst' f w e2) +  ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)    EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)    EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)    EZero t -> EZero t diff --git a/src/AST/Count.hs b/src/AST/Count.hs index d365218..364773a 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -125,7 +125,7 @@ occCountGeneral onehot unpush alter many = go WId        EIdx _ a b -> re a <> re b        EShape _ e -> re e        EOp _ _ e -> re e -      ECustom _ _ _ _ _ _ a b -> re a <> re b +      ECustom _ _ _ _ _ _ _ a b -> re a <> re b        EWith a b -> re a <> re1 b        EAccum _ a b e -> re a <> re b <> re e        EZero _ -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index bac267d..a2232ee 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -16,7 +16,6 @@ import AST  import AST.Count  import CHAD.Types  import Data -import ForwardAD.DualNumbers.Types  type SVal = SList (Const String) @@ -177,20 +176,24 @@ ppExpr' d val = \case                  (Prefix, s) -> s      return $ showParen (d > 10) $ showString (ops ++ " ") . e' -  ECustom _ t1 t2 a b c e1 e2 -> do -    pn1 <- genNameIfUsedIn t1 (IS IZ) a -    pn2 <- genNameIfUsedIn t2 IZ a -    fn1 <- genNameIfUsedIn t1 (IS IZ) b -    fn2 <- genNameIfUsedIn (dn t2) IZ b -    rn1 <- genNameIfUsedIn (d1 t1) (IS (IS IZ)) c -    rn2 <- genNameIfUsedIn (d1 t2) (IS IZ) c -    rn3 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c +  ECustom _ t1 t2 t3 a b c e1 e2 -> do +    en1 <- genNameIfUsedIn t1 (IS IZ) a +    en2 <- genNameIfUsedIn t2 IZ a +    pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b +    pn2 <- genNameIfUsedIn (d1 t2) IZ b +    dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c +    dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c      a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a -    b' <- ppExpr' 11 (Const fn2 `SCons` Const fn1 `SCons` SNil) b -    c' <- ppExpr' 11 (Const rn3 `SCons` Const rn2 `SCons` Const rn1 `SCons` SNil) c +    b' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) b +    c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c      e1' <- ppExpr' 11 val e1      e2' <- ppExpr' 11 val e2 -    return $ showParen (d > 10) $ showString "custom " . a' . showString " " . b' . showString " " . c' . showString " " . e1' . showString " " . e2' +    return $ showParen (d > 10) $ showString "custom " +      . showString ("(" ++ en1 ++ " " ++ en2 ++ ". ") . a' . showString ") " +      . showString ("(" ++ pn1 ++ " " ++ pn2 ++ ". ") . b' . showString ") " +      . showString ("(" ++ dn1 ++ " " ++ dn2 ++ ". ") . c' . showString ") " +      . e1' . showString " " +      . e2'    EWith e1 e2 -> do      e1' <- ppExpr' 11 val e1 diff --git a/src/CHAD.hs b/src/CHAD.hs index 8080ec0..2f05807 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -874,6 +874,20 @@ drev des = \case                                 (EVar ext (d2 (opt2 op)) IZ))                 (weakenExpr (WCopy (wSinks' @[_,_])) e2)) +  ECustom _ _ _ storety _ pr du a b +    -- allowed to ignore a2 because 'a' is the part of the input that is inactive +    | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) +        <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil -> +    Ret (binds `BPush` (typeOf a1, a1) +               `BPush` (typeOf b1, weakenExpr WSink b1) +               `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) pr) +               `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) +        (SEYes (SENo (SENo (SENo subtape)))) +        (EFst ext (EVar ext (typeOf pr) (IS IZ))) +        bsub +        (ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $ +           weakenExpr (WCopy (WSink .> WSink)) b2) +    EError t s ->      Ret BTop          SETop @@ -888,7 +902,6 @@ drev des = \case          (subenvNone (select SMerge des))          (ENil ext) -  -- TODO: merge the e0 and e1 builds in a single build just like they are merged into a single case in D[case]0, then it can really store only the parts that need to be preserved until D[build]2    EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty)      | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she  -- allowed to ignore she2 here because she has a discrete result      , let eltty = typeOf orige diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 3587378..f2ded6e 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -144,11 +144,10 @@ dfwdDN = \case      | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)      -> EShape ext (dfwdDN e)    EOp _ op e -> dop op (dfwdDN e) -  ECustom _ s t _ du _ e1 e2 -> -    -- TODO: we need a bit of codegen here that projects the primals out from the dual number result of e1. Note that a non-differentiating code transformation does not eliminate the need for this, because then the need just shifts to free variable adaptor code. -    ELet ext (_ e1) $ +  ECustom _ _ _ _ pr _ _ e1 e2 -> +    ELet ext (dfwdDN e1) $      ELet ext (weakenExpr WSink (dfwdDN e2)) $ -      weakenExpr (WCopy (WCopy WClosed)) du +      weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)    EError t s -> EError (dn t) s    EWith{} -> err_accum diff --git a/src/Interpreter.hs b/src/Interpreter.hs index abc9800..47514ae 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -118,6 +118,10 @@ interpret'Rec env = \case      -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)    EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e    EOp _ op e -> interpretOp op <$> interpret' env e +  ECustom _ _ _ _ pr _ _ e1 e2 -> do +    e1' <- interpret' env e1 +    e2' <- interpret' env e2 +    interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr    EWith e1 e2 -> do      initval <- interpret' env e1      withAccum (typeOf e1) (typeOf e2) initval $ \accum -> diff --git a/src/Simplify.hs b/src/Simplify.hs index d3ee03c..e32ba8c 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -121,11 +121,12 @@ simplify' = \case    EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b    EShape _ e -> EShape ext <$> simplify' e    EOp _ op e -> EOp ext op <$> simplify' e -  ECustom _ s t a b c e1 e2 -> -    ECustom ext s t <$> (let ?accumInScope = False in simplify' a) -                    <*> (let ?accumInScope = False in simplify' b) -                    <*> (let ?accumInScope = False in simplify' c) -                    <*> simplify' e1 <*> simplify' e2 +  ECustom _ s t p a b c e1 e2 -> +    ECustom ext s t p +      <$> (let ?accumInScope = False in simplify' a) +      <*> (let ?accumInScope = False in simplify' b) +      <*> (let ?accumInScope = False in simplify' c) +      <*> simplify' e1 <*> simplify' e2    EWith e1 e2 -> EWith <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)    EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3    EZero t -> pure $ EZero t @@ -165,7 +166,7 @@ hasAdds = \case    ESum1Inner _ e -> hasAdds e    EUnit _ e -> hasAdds e    EReplicate1Inner _ a b -> hasAdds a || hasAdds b -  ECustom _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e +  ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e    EConst _ _ _ -> False    EIdx0 _ e -> hasAdds e    EIdx1 _ a b -> hasAdds a || hasAdds b | 
