summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs62
1 files changed, 55 insertions, 7 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index cd4445e..26c918e 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -11,13 +11,9 @@
{-# LANGUAGE TypeApplications #-}
module CHAD where
-import Data.Functor.Const
-
import AST
-type Ex = Expr (Const ())
-
data Bindings f env env' where
BTop :: Bindings f env env
BPush :: Bindings f env env' -> (STy t, f env' t) -> Bindings f env (t : env')
@@ -200,6 +196,48 @@ retConcat (SCons (Ret (b :: Bindings Ex (D1E env) env2) p d) list)
(weakenExpr (WCopy (sinkWithBindings binds)) d))
pairs)
+d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
+d1op (OAdd t) e = EOp ext (OAdd t) e
+d1op (OMul t) e = EOp ext (OMul t) e
+d1op (ONeg t) e = EOp ext (ONeg t) e
+d1op (OLt t) e = EOp ext (OLt t) e
+d1op (OLe t) e = EOp ext (OLe t) e
+d1op (OEq t) e = EOp ext (OEq t) e
+d1op ONot e = EOp ext ONot e
+
+-- both primal and dual must be duplicable expressions
+d2op :: SOp a t -> Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)
+d2op op e d = case op of
+ OAdd _ -> EInr ext STNil (EPair ext d d)
+ OMul t -> d2opBinArrangeInt t $
+ EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
+ ONeg t -> d2opUnArrangeInt t $ EOp ext (ONeg t) d
+ OLt t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OLe t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OEq t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ ONot -> ENil ext
+ where
+ d2opUnArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => Ex env (D2 (TScal a)))
+ -> Ex env (D2 (TScal a))
+ d2opUnArrangeInt ty float = case ty of
+ STI32 -> ENil ext
+ STI64 -> ENil ext
+ STF32 -> float
+ STF64 -> float
+ STBool -> ENil ext
+
+ d2opBinArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => Ex env (D2 (TPair (TScal a) (TScal a))))
+ -> Ex env (D2 (TPair (TScal a) (TScal a)))
+ d2opBinArrangeInt ty float = case ty of
+ STI32 -> EInl ext (STPair STNil STNil) (ENil ext)
+ STI64 -> EInl ext (STPair STNil STNil) (ENil ext)
+ STF32 -> float
+ STF64 -> float
+ STBool -> EInl ext (STPair STNil STNil) (ENil ext)
+
drev :: SList STy env -> Ex env t -> Ret env t
drev senv = \case
EVar _ t i ->
@@ -309,7 +347,17 @@ drev senv = \case
(ESnd ext (EVar ext (STPair STNil (d2 t2)) IZ))))))))
(weakenExpr (WCopy (wSinks @[_,_,_])) e2))
- _ -> undefined
+ EConst _ t val ->
+ Ret BTop
+ (EConst ext t val)
+ (EMReturn (d2e senv) (ENil ext))
+
+ EOp _ op e
+ | Ret e0 e1 e2 <- drev senv e ->
+ Ret (e0 `BPush` (d1 (typeOf e), e1))
+ (d1op op $ EVar ext (d1 (typeOf e)) IZ)
+ (ELet ext (d2op op (EVar ext (d1 (typeOf e)) (IS IZ))
+ (EVar ext (d2 (opt2 op)) IZ))
+ (weakenExpr (WCopy (wSinks @[_,_])) e2))
-ext :: Const () a
-ext = Const ()
+ e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e)