summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2023-09-21 09:29:05 +0200
committerTom Smeding <tom@tomsmeding.com>2023-09-21 09:29:05 +0200
commit8d07a43f0b364156433dc453b9d1cc762c032634 (patch)
tree55af847d92daf708fbea995a2f58cc09144cea70
parent897fefce372f00d3e904e83eb92c83e3e653b5be (diff)
WIP mixed environment description
-rw-r--r--src/CHAD.hs156
1 files changed, 104 insertions, 52 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 9a1c7d2..0c856b1 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -9,8 +9,14 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
module CHAD where
+import Data.Bifunctor (first, second)
+import Data.Kind (Type)
+import GHC.TypeLits (Symbol)
+
import AST
@@ -104,6 +110,12 @@ type family D2E env where
D2E '[] = '[]
D2E (t : env) = D2 t : D2E env
+-- | Select only the types from the environment that have the specified storage
+type family Select env sto s where
+ Select '[] '[] _ = '[]
+ Select (t : ts) (s : sto) s = t : Select ts sto s
+ Select (_ : ts) (_ : sto) s = Select ts sto s
+
d1 :: STy t -> STy (D1 t)
d1 STNil = STNil
d1 (STPair a b) = STPair (d1 a) (d1 b)
@@ -125,21 +137,17 @@ d2 (STScal t) = case t of
STBool -> STNil
d2 STEVM{} = error "EVM not allowed in input program"
-d2e :: SList STy list -> SList STy (D2E list)
-d2e SNil = SNil
-d2e (SCons t list) = SCons (d2 t) (d2e list)
-
-d2list :: SList STy env -> SList STy (D2E env)
-d2list SNil = SNil
-d2list (SCons x l) = SCons (d2 x) (d2list l)
-
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)
-conv2Idx :: Idx env t -> Idx (D2E env) (D2 t)
-conv2Idx IZ = IZ
-conv2Idx (IS i) = IS (conv2Idx i)
+conv2Idx :: Descr env sto -> Idx env t -> Either (Idx (D2E (Select env sto "accum")) (D2 t))
+ (Idx (D2E (Select env sto "merge")) (D2 t))
+conv2Idx (DPush _ _ SAccum) IZ = Left IZ
+conv2Idx (DPush _ _ SMerge) IZ = Right IZ
+conv2Idx (DPush des _ SAccum) (IS i) = first IS (conv2Idx des i)
+conv2Idx (DPush des _ SMerge) (IS i) = second IS (conv2Idx des i)
+conv2Idx DTop i = case i of {}
zero :: STy t -> Ex env (D2 t)
zero STNil = ENil ext
@@ -154,23 +162,35 @@ zero (STScal t) = case t of
STBool -> ENil ext
zero STEVM{} = error "EVM not allowed in input program"
-data Ret env t =
+type family Tup env where
+ Tup '[] = TNil
+ Tup (t : ts) = TPair t (Tup ts)
+
+tTup :: SList STy env -> STy (Tup env)
+tTup SNil = STNil
+tTup (SCons t ts) = STPair t (tTup ts)
+
+zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
+zeroTup SNil = ENil ext
+zeroTup (SCons t env) = EPair ext (zero t) (zeroTup env)
+
+data Ret env sto t =
forall env'.
Ret (Bindings Ex (D1E env) env')
(Ex env' (D1 t))
- (Ex (D2 t : env') (TEVM (D2E env) TNil))
-deriving instance Show (Ret env t)
+ (Ex (D2 t : env') (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge")))))
+deriving instance Show (Ret env sto t)
-data RetPair env0 env t =
+data RetPair env0 sto env t =
RetPair (Ex env (D1 t))
- (Ex (D2 t : env) (TEVM (D2E env0) TNil))
+ (Ex (D2 t : env) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E (Select env0 sto "merge")))))
deriving (Show)
-data Rets env0 env list =
+data Rets env0 sto env list =
forall env'.
Rets (Bindings Ex env env')
- (SList (RetPair env0 env') list)
-deriving instance Show (Rets env0 env list)
+ (SList (RetPair env0 sto env') list)
+deriving instance Show (Rets env0 sto env list)
-- d1W :: env :> env' -> D1E env :> D1E env'
-- d1W WId = WId
@@ -179,15 +199,15 @@ deriving instance Show (Rets env0 env list)
-- d1W (WPop w) = WPop (d1W w)
-- d1W (WThen u w) = WThen (d1W u) (d1W w)
-weakenRetPair :: env :> env' -> RetPair env0 env t -> RetPair env0 env' t
+weakenRetPair :: env :> env' -> RetPair env0 sto env t -> RetPair env0 sto env' t
weakenRetPair w (RetPair e1 e2) = RetPair (weakenExpr w e1) (weakenExpr (WCopy w) e2)
-weakenRets :: env :> env' -> Rets env0 env list -> Rets env0 env' list
+weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
weakenRets w (Rets binds list) =
weakenBindings weakenExpr w binds $ \binds' wbinds' ->
Rets binds' (slistMap (weakenRetPair wbinds') list)
-retConcat :: forall env list. SList (Ret env) list -> Rets env (D1E env) list
+retConcat :: forall env sto list. SList (Ret env sto) list -> Rets env sto (D1E env) list
retConcat SNil = Rets BTop SNil
retConcat (SCons (Ret (b :: Bindings Ex (D1E env) env2) p d) list)
| Rets binds pairs <- weakenRets (sinkWithBindings b) (retConcat list)
@@ -206,10 +226,10 @@ d1op (OEq t) e = EOp ext (OEq t) e
d1op ONot e = EOp ext ONot e
d1op OIf e = EOp ext OIf e
+-- | Both primal and dual must be duplicable expressions
data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
| Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))
--- both primal and dual must be duplicable expressions
d2op :: SOp a t -> D2Op a t
d2op op = case op of
OAdd _ -> Linear $ \d -> EInr ext STNil (EPair ext d d)
@@ -243,21 +263,50 @@ d2op op = case op of
STF64 -> float
STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
-freezeRet :: Ret env t
+freezeRet :: Ret env sto t
-> (forall env'. Ex env' (D2 t)) -- the incoming cotangent value
- -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E env) TNil))
+ -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge")))))
freezeRet (Ret e0 e1 e2) d = letBinds e0 $ EPair ext e1 (ELet ext d e2)
-drev :: SList STy env -> Ex env t -> Ret env t
-drev senv = \case
+type Storage :: Symbol -> Type
+data Storage s where
+ SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator
+ SMerge :: Storage "merge" -- ^ just return and merge
+deriving instance Show (Storage s)
+
+-- | Environment description
+data Descr env sto where
+ DTop :: Descr '[] '[]
+ DPush :: Descr env sto -> STy t -> Storage s -> Descr (t : env) (s : sto)
+deriving instance Show (Descr env sto)
+
+select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
+select _ DTop = SNil
+select s@SAccum (DPush des t SAccum) = SCons t (select s des)
+select s@SMerge (DPush des _ SAccum) = select s des
+select s@SAccum (DPush des _ SMerge) = select s des
+select s@SMerge (DPush des t SMerge) = SCons t (select s des)
+
+d2e :: SList STy env -> SList STy (D2E env)
+d2e SNil = SNil
+d2e (SCons t ts) = SCons (d2 t) (d2e ts)
+
+drev :: Descr env sto -> Ex env t -> Ret env sto t
+drev des = \case
EVar _ t i ->
- Ret BTop
- (EVar ext (d1 t) (conv1Idx i))
- (EMOne (d2list senv) (conv2Idx i) (EVar ext (d2 t) IZ))
+ case conv2Idx des i of
+ Left accumI ->
+ Ret BTop
+ (EVar ext (d1 t) (conv1Idx i))
+ (EMBind
+ (EMOne d2mon accumI (EVar ext (d2 t) IZ))
+ (EMReturn d2mon (zeroTup (select SMerge des))))
+ Right tupI ->
+ _
ELet _ rhs body
- | Ret rhs0 rhs1 rhs2 <- drev senv rhs
- , Ret body0 body1 body2 <- drev (SCons (typeOf rhs) senv) body ->
+ | Ret rhs0 rhs1 rhs2 <- drev des rhs
+ , Ret body0 body1 body2 <- drev (DPush des (typeOf rhs) SMerge) body ->
weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 $ \body0' wbody0' ->
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
(weakenExpr wbody0' body1)
@@ -267,19 +316,19 @@ drev senv = \case
EPair _ a b
| Rets binds (RetPair a1 a2 `SCons` RetPair b1 b2 `SCons` SNil)
- <- retConcat $ drev senv a `SCons` drev senv b `SCons` SNil
+ <- retConcat $ drev des a `SCons` drev des b `SCons` SNil
, let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
Ret binds
(EPair ext a1 b1)
(ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)
- (EMReturn (d2e senv) (ENil ext))
+ (EMReturn d2mon (ENil ext))
(EMBind (ELet ext (EFst ext (EVar ext dt IZ))
(weakenExpr (WCopy (wSinks @[_,_])) a2))
(ELet ext (ESnd ext (EVar ext dt (IS IZ)))
(weakenExpr (WCopy (wSinks @[_,_,_])) b2))))
EFst _ e
- | Ret e0 e1 e2 <- drev senv e
+ | Ret e0 e1 e2 <- drev des e
, STPair t1 t2 <- typeOf e ->
Ret e0
(EFst ext e1)
@@ -287,40 +336,40 @@ drev senv = \case
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 e1 e2 <- drev senv e
+ | Ret e0 e1 e2 <- drev des e
, STPair t1 t2 <- typeOf e ->
Ret e0
(ESnd ext e1)
(ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
weakenExpr (WCopy WSink) e2)
- ENil _ -> Ret BTop (ENil ext) (EMReturn (d2e senv) (ENil ext))
+ ENil _ -> Ret BTop (ENil ext) (EMReturn d2mon (ENil ext))
EInl _ t2 e
- | Ret e0 e1 e2 <- drev senv e ->
+ | Ret e0 e1 e2 <- drev des e ->
Ret e0
(EInl ext (d1 t2) e1)
(ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ)
- (EMReturn (d2e senv) (ENil ext))
+ (EMReturn d2mon (ENil ext))
(ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ)
(weakenExpr (WCopy (wSinks @[_,_])) e2)
- (EError (STEVM (d2e senv) STNil) "inl<-dinr")))
+ (EError (STEVM d2mon STNil) "inl<-dinr")))
EInr _ t1 e
- | Ret e0 e1 e2 <- drev senv e ->
+ | Ret e0 e1 e2 <- drev des e ->
Ret e0
(EInr ext (d1 t1) e1)
(ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ)
- (EMReturn (d2e senv) (ENil ext))
+ (EMReturn d2mon (ENil ext))
(ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ)
- (EError (STEVM (d2e senv) STNil) "inr<-dinl")
+ (EError (STEVM d2mon STNil) "inr<-dinl")
(weakenExpr (WCopy (wSinks @[_,_])) e2)))
ECase _ e a b
| STEither t1 t2 <- typeOf e
- , Ret e0 e1 e2 <- drev senv e
- , Ret a0 a1 a2 <- drev (SCons t1 senv) a
- , Ret b0 b1 b2 <- drev (SCons t2 senv) b
+ , Ret e0 e1 e2 <- drev des e
+ , Ret a0 a1 a2 <- drev (DPush des t1 SMerge) a
+ , Ret b0 b1 b2 <- drev (DPush des t2 SMerge) b
, TupBinds tapeA collectA reconA <- tupBinds a0
, TupBinds tapeB collectB reconB <- tupBinds b0
, let tPrimal = STPair (d1 (typeOf a)) (STEither tapeA tapeB) ->
@@ -341,18 +390,18 @@ drev senv = \case
letBinds rebinds $
ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $
EMBind (weakenExpr (WCopy wrebinds) (EMScope a2))
- (EMReturn (d2e senv)
+ (EMReturn d2mon
(EInr ext STNil (EInl ext (d2 t2)
(ESnd ext (EVar ext (STPair STNil (d2 t1)) IZ))))))
- (EError (STEVM (d2e senv) (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape"))
+ (EError (STEVM d2mon (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape"))
(ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ))))
- (EError (STEVM (d2e senv) (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape")
+ (EError (STEVM d2mon (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape")
(case reconB (WSink .> WCopy (wSinks @[_,_,_] .> sinkWithBindings e0)) IZ of
TupBindsReconstruct rebinds wrebinds ->
letBinds rebinds $
ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $
EMBind (weakenExpr (WCopy wrebinds) (EMScope b2))
- (EMReturn (d2e senv)
+ (EMReturn d2mon
(EInr ext STNil (EInr ext (d2 t1)
(ESnd ext (EVar ext (STPair STNil (d2 t2)) IZ))))))))
(weakenExpr (WCopy (wSinks @[_,_,_])) e2))
@@ -360,10 +409,10 @@ drev senv = \case
EConst _ t val ->
Ret BTop
(EConst ext t val)
- (EMReturn (d2e senv) (ENil ext))
+ (EMReturn d2mon (ENil ext))
EOp _ op e
- | Ret e0 e1 e2 <- drev senv e ->
+ | Ret e0 e1 e2 <- drev des e ->
case d2op op of
Linear d2opfun ->
Ret e0
@@ -378,3 +427,6 @@ drev senv = \case
(weakenExpr (WCopy (wSinks @[_,_])) e2))
e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e)
+
+ where
+ d2mon = d2e (select SAccum des)