diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 62 |
1 files changed, 55 insertions, 7 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index cd4445e..26c918e 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -11,13 +11,9 @@ {-# LANGUAGE TypeApplications #-} module CHAD where -import Data.Functor.Const - import AST -type Ex = Expr (Const ()) - data Bindings f env env' where BTop :: Bindings f env env BPush :: Bindings f env env' -> (STy t, f env' t) -> Bindings f env (t : env') @@ -200,6 +196,48 @@ retConcat (SCons (Ret (b :: Bindings Ex (D1E env) env2) p d) list) (weakenExpr (WCopy (sinkWithBindings binds)) d)) pairs) +d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) +d1op (OAdd t) e = EOp ext (OAdd t) e +d1op (OMul t) e = EOp ext (OMul t) e +d1op (ONeg t) e = EOp ext (ONeg t) e +d1op (OLt t) e = EOp ext (OLt t) e +d1op (OLe t) e = EOp ext (OLe t) e +d1op (OEq t) e = EOp ext (OEq t) e +d1op ONot e = EOp ext ONot e + +-- both primal and dual must be duplicable expressions +d2op :: SOp a t -> Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a) +d2op op e d = case op of + OAdd _ -> EInr ext STNil (EPair ext d d) + OMul t -> d2opBinArrangeInt t $ + EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d))) + ONeg t -> d2opUnArrangeInt t $ EOp ext (ONeg t) d + OLt t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + OLe t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + OEq t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + ONot -> ENil ext + where + d2opUnArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => Ex env (D2 (TScal a))) + -> Ex env (D2 (TScal a)) + d2opUnArrangeInt ty float = case ty of + STI32 -> ENil ext + STI64 -> ENil ext + STF32 -> float + STF64 -> float + STBool -> ENil ext + + d2opBinArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => Ex env (D2 (TPair (TScal a) (TScal a)))) + -> Ex env (D2 (TPair (TScal a) (TScal a))) + d2opBinArrangeInt ty float = case ty of + STI32 -> EInl ext (STPair STNil STNil) (ENil ext) + STI64 -> EInl ext (STPair STNil STNil) (ENil ext) + STF32 -> float + STF64 -> float + STBool -> EInl ext (STPair STNil STNil) (ENil ext) + drev :: SList STy env -> Ex env t -> Ret env t drev senv = \case EVar _ t i -> @@ -309,7 +347,17 @@ drev senv = \case (ESnd ext (EVar ext (STPair STNil (d2 t2)) IZ)))))))) (weakenExpr (WCopy (wSinks @[_,_,_])) e2)) - _ -> undefined + EConst _ t val -> + Ret BTop + (EConst ext t val) + (EMReturn (d2e senv) (ENil ext)) + + EOp _ op e + | Ret e0 e1 e2 <- drev senv e -> + Ret (e0 `BPush` (d1 (typeOf e), e1)) + (d1op op $ EVar ext (d1 (typeOf e)) IZ) + (ELet ext (d2op op (EVar ext (d1 (typeOf e)) (IS IZ)) + (EVar ext (d2 (opt2 op)) IZ)) + (weakenExpr (WCopy (wSinks @[_,_])) e2)) -ext :: Const () a -ext = Const () + e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e) |