summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-08 20:29:29 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-08 20:29:29 +0100
commit4fcdb7118e0084f192753ea6c70394352a27d5ed (patch)
treec5e91ae438b6f4c3e5075bf591e5fbe28aa5d96b
parent83692cf41f76272423445c9cbbad65561ee3b50c (diff)
Custom derivatives
-rw-r--r--src/AST.hs15
-rw-r--r--src/AST/Count.hs2
-rw-r--r--src/AST/Pretty.hs27
-rw-r--r--src/CHAD.hs15
-rw-r--r--src/ForwardAD/DualNumbers.hs7
-rw-r--r--src/Interpreter.hs4
-rw-r--r--src/Simplify.hs13
7 files changed, 53 insertions, 30 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 6bae84a..08a5bba 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -26,7 +26,6 @@ import AST.Types
import AST.Weaken
import CHAD.Types
import Data
-import ForwardAD.DualNumbers.Types
-- | This index is flipped around from the usual direction: the smallest index
@@ -98,10 +97,14 @@ data Expr x env t where
EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
-- custom derivatives
- ECustom :: x t -> STy a -> STy b
+ -- 'b' is the part of the input of the operation that derivatives should
+ -- be backpropagated to; 'a' is the inactive part. The dual field of
+ -- ECustom does not allow a derivative to be generated for 'a', and hence
+ -- none is propagated.
+ ECustom :: x t -> STy a -> STy b -> STy tape
-> Expr x '[b, a] t -- ^ regular operation
- -> Expr x '[DN b, a] (DN t) -- ^ dual-numbers forward derivative
- -> Expr x '[D2 t, D1 b, D1 a] (D2 b) -- ^ CHAD reverse derivative
+ -> Expr x '[D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
+ -> Expr x '[D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
-> Expr x env a -> Expr x env b
-> Expr x env t
@@ -211,7 +214,7 @@ typeOf = \case
EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
EOp _ op _ -> opt2 op
- ECustom _ _ _ e _ _ _ _ -> typeOf e
+ ECustom _ _ _ _ e _ _ _ _ -> typeOf e
EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
EAccum _ _ _ _ -> STNil
@@ -285,7 +288,7 @@ subst' f w = \case
EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
EShape x e -> EShape x (subst' f w e)
EOp x op e -> EOp x op (subst' f w e)
- ECustom x s t a b c e1 e2 -> ECustom x s t a b c (subst' f w e1) (subst' f w e2)
+ ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
EZero t -> EZero t
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index d365218..364773a 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -125,7 +125,7 @@ occCountGeneral onehot unpush alter many = go WId
EIdx _ a b -> re a <> re b
EShape _ e -> re e
EOp _ _ e -> re e
- ECustom _ _ _ _ _ _ a b -> re a <> re b
+ ECustom _ _ _ _ _ _ _ a b -> re a <> re b
EWith a b -> re a <> re1 b
EAccum _ a b e -> re a <> re b <> re e
EZero _ -> mempty
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index bac267d..a2232ee 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -16,7 +16,6 @@ import AST
import AST.Count
import CHAD.Types
import Data
-import ForwardAD.DualNumbers.Types
type SVal = SList (Const String)
@@ -177,20 +176,24 @@ ppExpr' d val = \case
(Prefix, s) -> s
return $ showParen (d > 10) $ showString (ops ++ " ") . e'
- ECustom _ t1 t2 a b c e1 e2 -> do
- pn1 <- genNameIfUsedIn t1 (IS IZ) a
- pn2 <- genNameIfUsedIn t2 IZ a
- fn1 <- genNameIfUsedIn t1 (IS IZ) b
- fn2 <- genNameIfUsedIn (dn t2) IZ b
- rn1 <- genNameIfUsedIn (d1 t1) (IS (IS IZ)) c
- rn2 <- genNameIfUsedIn (d1 t2) (IS IZ) c
- rn3 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c
+ ECustom _ t1 t2 t3 a b c e1 e2 -> do
+ en1 <- genNameIfUsedIn t1 (IS IZ) a
+ en2 <- genNameIfUsedIn t2 IZ a
+ pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b
+ pn2 <- genNameIfUsedIn (d1 t2) IZ b
+ dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c
+ dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c
a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a
- b' <- ppExpr' 11 (Const fn2 `SCons` Const fn1 `SCons` SNil) b
- c' <- ppExpr' 11 (Const rn3 `SCons` Const rn2 `SCons` Const rn1 `SCons` SNil) c
+ b' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) b
+ c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
- return $ showParen (d > 10) $ showString "custom " . a' . showString " " . b' . showString " " . c' . showString " " . e1' . showString " " . e2'
+ return $ showParen (d > 10) $ showString "custom "
+ . showString ("(" ++ en1 ++ " " ++ en2 ++ ". ") . a' . showString ") "
+ . showString ("(" ++ pn1 ++ " " ++ pn2 ++ ". ") . b' . showString ") "
+ . showString ("(" ++ dn1 ++ " " ++ dn2 ++ ". ") . c' . showString ") "
+ . e1' . showString " "
+ . e2'
EWith e1 e2 -> do
e1' <- ppExpr' 11 val e1
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 8080ec0..2f05807 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -874,6 +874,20 @@ drev des = \case
(EVar ext (d2 (opt2 op)) IZ))
(weakenExpr (WCopy (wSinks' @[_,_])) e2))
+ ECustom _ _ _ storety _ pr du a b
+ -- allowed to ignore a2 because 'a' is the part of the input that is inactive
+ | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil)
+ <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil ->
+ Ret (binds `BPush` (typeOf a1, a1)
+ `BPush` (typeOf b1, weakenExpr WSink b1)
+ `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) pr)
+ `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
+ (SEYes (SENo (SENo (SENo subtape))))
+ (EFst ext (EVar ext (typeOf pr) (IS IZ)))
+ bsub
+ (ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $
+ weakenExpr (WCopy (WSink .> WSink)) b2)
+
EError t s ->
Ret BTop
SETop
@@ -888,7 +902,6 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
- -- TODO: merge the e0 and e1 builds in a single build just like they are merged into a single case in D[case]0, then it can really store only the parts that need to be preserved until D[build]2
EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty)
| Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result
, let eltty = typeOf orige
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 3587378..f2ded6e 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -144,11 +144,10 @@ dfwdDN = \case
| Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
-> EShape ext (dfwdDN e)
EOp _ op e -> dop op (dfwdDN e)
- ECustom _ s t _ du _ e1 e2 ->
- -- TODO: we need a bit of codegen here that projects the primals out from the dual number result of e1. Note that a non-differentiating code transformation does not eliminate the need for this, because then the need just shifts to free variable adaptor code.
- ELet ext (_ e1) $
+ ECustom _ _ _ _ pr _ _ e1 e2 ->
+ ELet ext (dfwdDN e1) $
ELet ext (weakenExpr WSink (dfwdDN e2)) $
- weakenExpr (WCopy (WCopy WClosed)) du
+ weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
EError t s -> EError (dn t) s
EWith{} -> err_accum
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index abc9800..47514ae 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -118,6 +118,10 @@ interpret'Rec env = \case
-> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
EOp _ op e -> interpretOp op <$> interpret' env e
+ ECustom _ _ _ _ pr _ _ e1 e2 -> do
+ e1' <- interpret' env e1
+ e2' <- interpret' env e2
+ interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr
EWith e1 e2 -> do
initval <- interpret' env e1
withAccum (typeOf e1) (typeOf e2) initval $ \accum ->
diff --git a/src/Simplify.hs b/src/Simplify.hs
index d3ee03c..e32ba8c 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -121,11 +121,12 @@ simplify' = \case
EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b
EShape _ e -> EShape ext <$> simplify' e
EOp _ op e -> EOp ext op <$> simplify' e
- ECustom _ s t a b c e1 e2 ->
- ECustom ext s t <$> (let ?accumInScope = False in simplify' a)
- <*> (let ?accumInScope = False in simplify' b)
- <*> (let ?accumInScope = False in simplify' c)
- <*> simplify' e1 <*> simplify' e2
+ ECustom _ s t p a b c e1 e2 ->
+ ECustom ext s t p
+ <$> (let ?accumInScope = False in simplify' a)
+ <*> (let ?accumInScope = False in simplify' b)
+ <*> (let ?accumInScope = False in simplify' c)
+ <*> simplify' e1 <*> simplify' e2
EWith e1 e2 -> EWith <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
EZero t -> pure $ EZero t
@@ -165,7 +166,7 @@ hasAdds = \case
ESum1Inner _ e -> hasAdds e
EUnit _ e -> hasAdds e
EReplicate1Inner _ a b -> hasAdds a || hasAdds b
- ECustom _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
+ ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
EConst _ _ _ -> False
EIdx0 _ e -> hasAdds e
EIdx1 _ a b -> hasAdds a || hasAdds b