diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 60 |
1 files changed, 36 insertions, 24 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index d0358b8..9a1c7d2 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -204,39 +204,44 @@ d1op (OLt t) e = EOp ext (OLt t) e d1op (OLe t) e = EOp ext (OLe t) e d1op (OEq t) e = EOp ext (OEq t) e d1op ONot e = EOp ext ONot e +d1op OIf e = EOp ext OIf e + +data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) + | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) -- both primal and dual must be duplicable expressions -d2op :: SOp a t -> Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a) -d2op op e d = case op of - OAdd _ -> EInr ext STNil (EPair ext d d) - OMul t -> d2opBinArrangeInt t $ +d2op :: SOp a t -> D2Op a t +d2op op = case op of + OAdd _ -> Linear $ \d -> EInr ext STNil (EPair ext d d) + OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) (EOp ext (OMul t) (EPair ext (EFst ext e) d))) - ONeg t -> d2opUnArrangeInt t $ EOp ext (ONeg t) d - OLt t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) - OLe t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) - OEq t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) - ONot -> ENil ext + ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d + OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + ONot -> Linear $ \_ -> ENil ext + OIf -> Linear $ \_ -> ENil ext where d2opUnArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => Ex env (D2 (TScal a))) - -> Ex env (D2 (TScal a)) + -> (D2s a ~ TScal a => D2Op (TScal a) t) + -> D2Op (TScal a) t d2opUnArrangeInt ty float = case ty of - STI32 -> ENil ext - STI64 -> ENil ext + STI32 -> Linear $ \_ -> ENil ext + STI64 -> Linear $ \_ -> ENil ext STF32 -> float STF64 -> float - STBool -> ENil ext + STBool -> Linear $ \_ -> ENil ext d2opBinArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => Ex env (D2 (TPair (TScal a) (TScal a)))) - -> Ex env (D2 (TPair (TScal a) (TScal a))) + -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) + -> D2Op (TPair (TScal a) (TScal a)) t d2opBinArrangeInt ty float = case ty of - STI32 -> EInl ext (STPair STNil STNil) (ENil ext) - STI64 -> EInl ext (STPair STNil STNil) (ENil ext) + STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) STF32 -> float STF64 -> float - STBool -> EInl ext (STPair STNil STNil) (ENil ext) + STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) freezeRet :: Ret env t -> (forall env'. Ex env' (D2 t)) -- the incoming cotangent value @@ -359,10 +364,17 @@ drev senv = \case EOp _ op e | Ret e0 e1 e2 <- drev senv e -> - Ret (e0 `BPush` (d1 (typeOf e), e1)) - (d1op op $ EVar ext (d1 (typeOf e)) IZ) - (ELet ext (d2op op (EVar ext (d1 (typeOf e)) (IS IZ)) - (EVar ext (d2 (opt2 op)) IZ)) - (weakenExpr (WCopy (wSinks @[_,_])) e2)) + case d2op op of + Linear d2opfun -> + Ret e0 + (d1op op e1) + (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ)) + (weakenExpr (WCopy WSink) e2)) + Nonlinear d2opfun -> + Ret (e0 `BPush` (d1 (typeOf e), e1)) + (d1op op $ EVar ext (d1 (typeOf e)) IZ) + (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) + (EVar ext (d2 (opt2 op)) IZ)) + (weakenExpr (WCopy (wSinks @[_,_])) e2)) e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e) |