From 4fcdb7118e0084f192753ea6c70394352a27d5ed Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Fri, 8 Nov 2024 20:29:29 +0100
Subject: Custom derivatives

---
 src/AST.hs                   | 15 +++++++++------
 src/AST/Count.hs             |  2 +-
 src/AST/Pretty.hs            | 27 +++++++++++++++------------
 src/CHAD.hs                  | 15 ++++++++++++++-
 src/ForwardAD/DualNumbers.hs |  7 +++----
 src/Interpreter.hs           |  4 ++++
 src/Simplify.hs              | 13 +++++++------
 7 files changed, 53 insertions(+), 30 deletions(-)

(limited to 'src')

diff --git a/src/AST.hs b/src/AST.hs
index 6bae84a..08a5bba 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -26,7 +26,6 @@ import AST.Types
 import AST.Weaken
 import CHAD.Types
 import Data
-import ForwardAD.DualNumbers.Types
 
 
 -- | This index is flipped around from the usual direction: the smallest index
@@ -98,10 +97,14 @@ data Expr x env t where
   EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
 
   -- custom derivatives
-  ECustom :: x t -> STy a -> STy b
+  -- 'b' is the part of the input of the operation that derivatives should
+  -- be backpropagated to; 'a' is the inactive part. The dual field of
+  -- ECustom does not allow a derivative to be generated for 'a', and hence
+  -- none is propagated.
+  ECustom :: x t -> STy a -> STy b -> STy tape
           -> Expr x '[b, a] t  -- ^ regular operation
-          -> Expr x '[DN b, a] (DN t)  -- ^ dual-numbers forward derivative
-          -> Expr x '[D2 t, D1 b, D1 a] (D2 b)  -- ^ CHAD reverse derivative
+          -> Expr x '[D1 b, D1 a] (TPair (D1 t) tape)  -- ^ CHAD forward pass
+          -> Expr x '[D2 t, tape] (D2 b)  -- ^ CHAD reverse derivative
           -> Expr x env a -> Expr x env b
           -> Expr x env t
 
@@ -211,7 +214,7 @@ typeOf = \case
   EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
   EOp _ op _ -> opt2 op
 
-  ECustom _ _ _ e _ _ _ _ -> typeOf e
+  ECustom _ _ _ _ e _ _ _ _ -> typeOf e
 
   EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
   EAccum _ _ _ _ -> STNil
@@ -285,7 +288,7 @@ subst' f w = \case
   EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
   EShape x e -> EShape x (subst' f w e)
   EOp x op e -> EOp x op (subst' f w e)
-  ECustom x s t a b c e1 e2 -> ECustom x s t a b c (subst' f w e1) (subst' f w e2)
+  ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
   EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
   EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
   EZero t -> EZero t
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index d365218..364773a 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -125,7 +125,7 @@ occCountGeneral onehot unpush alter many = go WId
       EIdx _ a b -> re a <> re b
       EShape _ e -> re e
       EOp _ _ e -> re e
-      ECustom _ _ _ _ _ _ a b -> re a <> re b
+      ECustom _ _ _ _ _ _ _ a b -> re a <> re b
       EWith a b -> re a <> re1 b
       EAccum _ a b e -> re a <> re b <> re e
       EZero _ -> mempty
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index bac267d..a2232ee 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -16,7 +16,6 @@ import AST
 import AST.Count
 import CHAD.Types
 import Data
-import ForwardAD.DualNumbers.Types
 
 
 type SVal = SList (Const String)
@@ -177,20 +176,25 @@ ppExpr' d val = \case
                 (Prefix, s) -> s
     return $ showParen (d > 10) $ showString (ops ++ " ") . e'
 
-  ECustom _ t1 t2 a b c e1 e2 -> do
-    pn1 <- genNameIfUsedIn t1 (IS IZ) a
-    pn2 <- genNameIfUsedIn t2 IZ a
-    fn1 <- genNameIfUsedIn t1 (IS IZ) b
-    fn2 <- genNameIfUsedIn (dn t2) IZ b
-    rn1 <- genNameIfUsedIn (d1 t1) (IS (IS IZ)) c
-    rn2 <- genNameIfUsedIn (d1 t2) (IS IZ) c
-    rn3 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c
+  ECustom _ t1 t2 t3 a b c e1 e2 -> do
+    en1 <- genNameIfUsedIn t1 (IS IZ) a
+    en2 <- genNameIfUsedIn t2 IZ a
+    pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b
+    pn2 <- genNameIfUsedIn (d1 t2) IZ b
+    dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c
+    dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c
-    a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a
-    b' <- ppExpr' 11 (Const fn2 `SCons` Const fn1 `SCons` SNil) b
-    c' <- ppExpr' 11 (Const rn3 `SCons` Const rn2 `SCons` Const rn1 `SCons` SNil) c
+    -- Each body is printed under the names generated for its own two binders.
+    a' <- ppExpr' 11 (Const en2 `SCons` Const en1 `SCons` SNil) a
+    b' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) b
+    c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c
     e1' <- ppExpr' 11 val e1
     e2' <- ppExpr' 11 val e2
-    return $ showParen (d > 10) $ showString "custom " . a' . showString " " . b' . showString " " . c' . showString " " . e1' . showString " " . e2'
+    return $ showParen (d > 10) $ showString "custom "
+      . showString ("(" ++ en1 ++ " " ++ en2 ++ ". ") . a' . showString ") "
+      . showString ("(" ++ pn1 ++ " " ++ pn2 ++ ". ") . b' . showString ") "
+      . showString ("(" ++ dn1 ++ " " ++ dn2 ++ ". ") . c' . showString ") "
+      . e1' . showString " "
+      . e2'
 
   EWith e1 e2 -> do
     e1' <- ppExpr' 11 val e1
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 8080ec0..2f05807 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -874,6 +874,20 @@ drev des = \case
                                (EVar ext (d2 (opt2 op)) IZ))
                (weakenExpr (WCopy (wSinks' @[_,_])) e2))
 
+  ECustom _ _ _ storety _ pr du a b
+    -- allowed to ignore a2 because 'a' is the part of the input that is inactive
+    | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil)
+        <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil ->
+    Ret (binds `BPush` (typeOf a1, a1)
+               `BPush` (typeOf b1, weakenExpr WSink b1)
+               `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) pr)
+               `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
+        (SEYes (SENo (SENo (SENo subtape))))
+        (EFst ext (EVar ext (typeOf pr) (IS IZ)))
+        bsub
+        (ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $
+           weakenExpr (WCopy (WSink .> WSink)) b2)
+
   EError t s ->
     Ret BTop
         SETop
@@ -888,7 +902,6 @@ drev des = \case
         (subenvNone (select SMerge des))
         (ENil ext)
 
-  -- TODO: merge the e0 and e1 builds in a single build just like they are merged into a single case in D[case]0, then it can really store only the parts that need to be preserved until D[build]2
   EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty)
     | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she  -- allowed to ignore she2 here because she has a discrete result
     , let eltty = typeOf orige
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 3587378..f2ded6e 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -144,11 +144,10 @@ dfwdDN = \case
     | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
     -> EShape ext (dfwdDN e)
   EOp _ op e -> dop op (dfwdDN e)
-  ECustom _ s t _ du _ e1 e2 ->
-    -- TODO: we need a bit of codegen here that projects the primals out from the dual number result of e1. Note that a non-differentiating code transformation does not eliminate the need for this, because then the need just shifts to free variable adaptor code.
-    ELet ext (_ e1) $
+  ECustom _ _ _ _ pr _ _ e1 e2 ->
+    ELet ext (dfwdDN e1) $
     ELet ext (weakenExpr WSink (dfwdDN e2)) $
-      weakenExpr (WCopy (WCopy WClosed)) du
+      weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
   EError t s -> EError (dn t) s
 
   EWith{} -> err_accum
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index abc9800..47514ae 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -118,6 +118,10 @@ interpret'Rec env = \case
     -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
   EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
   EOp _ op e -> interpretOp op <$> interpret' env e
+  ECustom _ _ _ _ pr _ _ e1 e2 -> do
+    e1' <- interpret' env e1
+    e2' <- interpret' env e2
+    interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr
   EWith e1 e2 -> do
     initval <- interpret' env e1
     withAccum (typeOf e1) (typeOf e2) initval $ \accum ->
diff --git a/src/Simplify.hs b/src/Simplify.hs
index d3ee03c..e32ba8c 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -121,11 +121,12 @@ simplify' = \case
   EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b
   EShape _ e -> EShape ext <$> simplify' e
   EOp _ op e -> EOp ext op <$> simplify' e
-  ECustom _ s t a b c e1 e2 ->
-    ECustom ext s t <$> (let ?accumInScope = False in simplify' a)
-                    <*> (let ?accumInScope = False in simplify' b)
-                    <*> (let ?accumInScope = False in simplify' c)
-                    <*> simplify' e1 <*> simplify' e2
+  ECustom _ s t p a b c e1 e2 ->
+    ECustom ext s t p
+      <$> (let ?accumInScope = False in simplify' a)
+      <*> (let ?accumInScope = False in simplify' b)
+      <*> (let ?accumInScope = False in simplify' c)
+      <*> simplify' e1 <*> simplify' e2
   EWith e1 e2 -> EWith <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
   EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
   EZero t -> pure $ EZero t
@@ -165,7 +166,7 @@ hasAdds = \case
   ESum1Inner _ e -> hasAdds e
   EUnit _ e -> hasAdds e
   EReplicate1Inner _ a b -> hasAdds a || hasAdds b
-  ECustom _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
+  ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
   EConst _ _ _ -> False
   EIdx0 _ e -> hasAdds e
   EIdx1 _ a b -> hasAdds a || hasAdds b
-- 
cgit v1.2.3-70-g09d2