summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs60
1 files changed, 36 insertions, 24 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index d0358b8..9a1c7d2 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -204,39 +204,44 @@ d1op (OLt t) e = EOp ext (OLt t) e
d1op (OLe t) e = EOp ext (OLe t) e
d1op (OEq t) e = EOp ext (OEq t) e
d1op ONot e = EOp ext ONot e
+d1op OIf e = EOp ext OIf e
+
+data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
+ | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))
-- both primal and dual must be duplicable expressions
-d2op :: SOp a t -> Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)
-d2op op e d = case op of
- OAdd _ -> EInr ext STNil (EPair ext d d)
- OMul t -> d2opBinArrangeInt t $
+d2op :: SOp a t -> D2Op a t
+d2op op = case op of
+ OAdd _ -> Linear $ \d -> EInr ext STNil (EPair ext d d)
+ OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
(EOp ext (OMul t) (EPair ext (EFst ext e) d)))
- ONeg t -> d2opUnArrangeInt t $ EOp ext (ONeg t) d
- OLt t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OLe t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OEq t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- ONot -> ENil ext
+ ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
+ OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ ONot -> Linear $ \_ -> ENil ext
+ OIf -> Linear $ \_ -> ENil ext
where
d2opUnArrangeInt :: SScalTy a
- -> (D2s a ~ TScal a => Ex env (D2 (TScal a)))
- -> Ex env (D2 (TScal a))
+ -> (D2s a ~ TScal a => D2Op (TScal a) t)
+ -> D2Op (TScal a) t
d2opUnArrangeInt ty float = case ty of
- STI32 -> ENil ext
- STI64 -> ENil ext
+ STI32 -> Linear $ \_ -> ENil ext
+ STI64 -> Linear $ \_ -> ENil ext
STF32 -> float
STF64 -> float
- STBool -> ENil ext
+ STBool -> Linear $ \_ -> ENil ext
d2opBinArrangeInt :: SScalTy a
- -> (D2s a ~ TScal a => Ex env (D2 (TPair (TScal a) (TScal a))))
- -> Ex env (D2 (TPair (TScal a) (TScal a)))
+ -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
+ -> D2Op (TPair (TScal a) (TScal a)) t
d2opBinArrangeInt ty float = case ty of
- STI32 -> EInl ext (STPair STNil STNil) (ENil ext)
- STI64 -> EInl ext (STPair STNil STNil) (ENil ext)
+ STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
STF32 -> float
STF64 -> float
- STBool -> EInl ext (STPair STNil STNil) (ENil ext)
+ STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
freezeRet :: Ret env t
-> (forall env'. Ex env' (D2 t)) -- the incoming cotangent value
@@ -359,10 +364,17 @@ drev senv = \case
EOp _ op e
| Ret e0 e1 e2 <- drev senv e ->
- Ret (e0 `BPush` (d1 (typeOf e), e1))
- (d1op op $ EVar ext (d1 (typeOf e)) IZ)
- (ELet ext (d2op op (EVar ext (d1 (typeOf e)) (IS IZ))
- (EVar ext (d2 (opt2 op)) IZ))
- (weakenExpr (WCopy (wSinks @[_,_])) e2))
+ case d2op op of
+ Linear d2opfun ->
+ Ret e0
+ (d1op op e1)
+ (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ))
+ (weakenExpr (WCopy WSink) e2))
+ Nonlinear d2opfun ->
+ Ret (e0 `BPush` (d1 (typeOf e), e1))
+ (d1op op $ EVar ext (d1 (typeOf e)) IZ)
+ (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ))
+ (EVar ext (d2 (opt2 op)) IZ))
+ (weakenExpr (WCopy (wSinks @[_,_])) e2))
e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e)