summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2023-09-16 12:32:56 +0200
committerTom Smeding <tom@tomsmeding.com>2023-09-16 12:32:56 +0200
commit35cc10682f35dafba98000bf35191896a6432624 (patch)
tree7fbe9080ec502f3398a5329e15e5d263182d91a3
parente52a3e7e89f6ad41d4291a467e4c1d3571614b0a (diff)
CHAD ops
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/AST.hs7
-rw-r--r--src/CHAD.hs62
-rw-r--r--src/Example.hs22
4 files changed, 85 insertions, 7 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index efcf44e..dd5bb27 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -14,6 +14,7 @@ library
AST.Weaken
CHAD
-- Compile
+ Example
PreludeCu
other-modules:
build-depends:
diff --git a/src/AST.hs b/src/AST.hs
index b1f3e5d..7c5de11 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -13,6 +13,8 @@
{-# LANGUAGE DeriveFoldable #-}
module AST (module AST, module AST.Weaken) where
+import Data.Functor.Const
+
import Data.Kind (Type)
import Data.Int
@@ -126,6 +128,11 @@ data Expr x env t where
EError :: STy a -> String -> Expr x env a
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
+type Ex = Expr (Const ())
+
+ext :: Const () a
+ext = Const ()
+
type SOp :: Ty -> Ty -> Type
data SOp a t where
OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index cd4445e..26c918e 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -11,13 +11,9 @@
{-# LANGUAGE TypeApplications #-}
module CHAD where
-import Data.Functor.Const
-
import AST
-type Ex = Expr (Const ())
-
data Bindings f env env' where
BTop :: Bindings f env env
BPush :: Bindings f env env' -> (STy t, f env' t) -> Bindings f env (t : env')
@@ -200,6 +196,48 @@ retConcat (SCons (Ret (b :: Bindings Ex (D1E env) env2) p d) list)
(weakenExpr (WCopy (sinkWithBindings binds)) d))
pairs)
+d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
+d1op (OAdd t) e = EOp ext (OAdd t) e
+d1op (OMul t) e = EOp ext (OMul t) e
+d1op (ONeg t) e = EOp ext (ONeg t) e
+d1op (OLt t) e = EOp ext (OLt t) e
+d1op (OLe t) e = EOp ext (OLe t) e
+d1op (OEq t) e = EOp ext (OEq t) e
+d1op ONot e = EOp ext ONot e
+
+-- both primal and dual must be duplicable expressions
+d2op :: SOp a t -> Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)
+d2op op e d = case op of
+ OAdd _ -> EInr ext STNil (EPair ext d d)
+ OMul t -> d2opBinArrangeInt t $
+ EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
+ ONeg t -> d2opUnArrangeInt t $ EOp ext (ONeg t) d
+ OLt t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OLe t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OEq t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ ONot -> ENil ext
+ where
+ d2opUnArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => Ex env (D2 (TScal a)))
+ -> Ex env (D2 (TScal a))
+ d2opUnArrangeInt ty float = case ty of
+ STI32 -> ENil ext
+ STI64 -> ENil ext
+ STF32 -> float
+ STF64 -> float
+ STBool -> ENil ext
+
+ d2opBinArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => Ex env (D2 (TPair (TScal a) (TScal a))))
+ -> Ex env (D2 (TPair (TScal a) (TScal a)))
+ d2opBinArrangeInt ty float = case ty of
+ STI32 -> EInl ext (STPair STNil STNil) (ENil ext)
+ STI64 -> EInl ext (STPair STNil STNil) (ENil ext)
+ STF32 -> float
+ STF64 -> float
+ STBool -> EInl ext (STPair STNil STNil) (ENil ext)
+
drev :: SList STy env -> Ex env t -> Ret env t
drev senv = \case
EVar _ t i ->
@@ -309,7 +347,17 @@ drev senv = \case
(ESnd ext (EVar ext (STPair STNil (d2 t2)) IZ))))))))
(weakenExpr (WCopy (wSinks @[_,_,_])) e2))
- _ -> undefined
+ EConst _ t val ->
+ Ret BTop
+ (EConst ext t val)
+ (EMReturn (d2e senv) (ENil ext))
+
+ EOp _ op e
+ | Ret e0 e1 e2 <- drev senv e ->
+ Ret (e0 `BPush` (d1 (typeOf e), e1))
+ (d1op op $ EVar ext (d1 (typeOf e)) IZ)
+ (ELet ext (d2op op (EVar ext (d1 (typeOf e)) (IS IZ))
+ (EVar ext (d2 (opt2 op)) IZ))
+ (weakenExpr (WCopy (wSinks @[_,_])) e2))
-ext :: Const () a
-ext = Const ()
+ e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e)
diff --git a/src/Example.hs b/src/Example.hs
new file mode 100644
index 0000000..89b2082
--- /dev/null
+++ b/src/Example.hs
@@ -0,0 +1,22 @@
+{-# LANGUAGE DataKinds #-}
+module Example where
+
+import AST
+import CHAD
+
+
+bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
+bin op a b = EOp ext op (EPair ext a b)
+
+-- x y |- x * y + x
+ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex1 =
+ bin (OAdd STF32)
+ (bin (OMul STF32)
+ (EVar ext (STScal STF32) (IS IZ))
+ (EVar ext (STScal STF32) IZ))
+ (EVar ext (STScal STF32) (IS IZ))
+
+-- -- x y |- let z = x + y in z * (z + x)
+-- ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+-- ex2 = _