diff options
author | Tom Smeding <tom@tomsmeding.com> | 2023-09-16 12:32:56 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2023-09-16 12:32:56 +0200 |
commit | 35cc10682f35dafba98000bf35191896a6432624 (patch) | |
tree | 7fbe9080ec502f3398a5329e15e5d263182d91a3 | |
parent | e52a3e7e89f6ad41d4291a467e4c1d3571614b0a (diff) |
CHAD ops
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/AST.hs | 7 | ||||
-rw-r--r-- | src/CHAD.hs | 62 | ||||
-rw-r--r-- | src/Example.hs | 22 |
4 files changed, 85 insertions, 7 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index efcf44e..dd5bb27 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -14,6 +14,7 @@ library AST.Weaken CHAD -- Compile + Example PreludeCu other-modules: build-depends: @@ -13,6 +13,8 @@ {-# LANGUAGE DeriveFoldable #-} module AST (module AST, module AST.Weaken) where +import Data.Functor.Const + import Data.Kind (Type) import Data.Int @@ -126,6 +128,11 @@ data Expr x env t where EError :: STy a -> String -> Expr x env a deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) +type Ex = Expr (Const ()) + +ext :: Const () a +ext = Const () + type SOp :: Ty -> Ty -> Type data SOp a t where OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) diff --git a/src/CHAD.hs b/src/CHAD.hs index cd4445e..26c918e 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -11,13 +11,9 @@ {-# LANGUAGE TypeApplications #-} module CHAD where -import Data.Functor.Const - import AST -type Ex = Expr (Const ()) - data Bindings f env env' where BTop :: Bindings f env env BPush :: Bindings f env env' -> (STy t, f env' t) -> Bindings f env (t : env') @@ -200,6 +196,48 @@ retConcat (SCons (Ret (b :: Bindings Ex (D1E env) env2) p d) list) (weakenExpr (WCopy (sinkWithBindings binds)) d)) pairs) +d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) +d1op (OAdd t) e = EOp ext (OAdd t) e +d1op (OMul t) e = EOp ext (OMul t) e +d1op (ONeg t) e = EOp ext (ONeg t) e +d1op (OLt t) e = EOp ext (OLt t) e +d1op (OLe t) e = EOp ext (OLe t) e +d1op (OEq t) e = EOp ext (OEq t) e +d1op ONot e = EOp ext ONot e + +-- both primal and dual must be duplicable expressions +d2op :: SOp a t -> Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a) +d2op op e d = case op of + OAdd _ -> EInr ext STNil (EPair ext d d) + OMul t -> d2opBinArrangeInt t $ + EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d))) + ONeg t -> d2opUnArrangeInt t $ EOp ext (ONeg t) d + OLt t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + OLe t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + OEq t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + ONot -> ENil ext + where + d2opUnArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => Ex env (D2 (TScal a))) + -> Ex env (D2 (TScal a)) + d2opUnArrangeInt ty float = case ty of + STI32 -> ENil ext + STI64 -> ENil ext + STF32 -> float + STF64 -> float + STBool -> ENil ext + + d2opBinArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => Ex env (D2 (TPair (TScal a) (TScal a)))) + -> Ex env (D2 (TPair (TScal a) (TScal a))) + d2opBinArrangeInt ty float = case ty of + STI32 -> EInl ext (STPair STNil STNil) (ENil ext) + STI64 -> EInl ext (STPair STNil STNil) (ENil ext) + STF32 -> float + STF64 -> float + STBool -> EInl ext (STPair STNil STNil) (ENil ext) + drev :: SList STy env -> Ex env t -> Ret env t drev senv = \case EVar _ t i -> @@ -309,7 +347,17 @@ drev senv = \case (ESnd ext (EVar ext (STPair STNil (d2 t2)) IZ)))))))) (weakenExpr (WCopy (wSinks @[_,_,_])) e2)) - _ -> undefined + EConst _ t val -> + Ret BTop + (EConst ext t val) + (EMReturn (d2e senv) (ENil ext)) + + EOp _ op e + | Ret e0 e1 e2 <- drev senv e -> + Ret (e0 `BPush` (d1 (typeOf e), e1)) + (d1op op $ EVar ext (d1 (typeOf e)) IZ) + (ELet ext (d2op op (EVar ext (d1 (typeOf e)) (IS IZ)) + (EVar ext (d2 (opt2 op)) IZ)) + (weakenExpr (WCopy (wSinks @[_,_])) e2)) -ext :: Const () a -ext = Const () + e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e) diff --git a/src/Example.hs b/src/Example.hs new file mode 100644 index 0000000..89b2082 --- /dev/null +++ b/src/Example.hs @@ -0,0 +1,22 @@ +{-# LANGUAGE DataKinds #-} +module Example where + +import AST +import CHAD + + +bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c +bin op a b = EOp ext op (EPair ext a b) + +-- x y |- x * y + x +ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex1 = + bin (OAdd STF32) + (bin (OMul STF32) + (EVar ext (STScal STF32) (IS IZ)) + (EVar ext (STScal STF32) IZ)) + (EVar ext (STScal STF32) (IS IZ)) + +-- -- x y |- let z = x + y in z * (z + x) +-- ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +-- ex2 = _ |