diff options
author | Tom Smeding <tom@tomsmeding.com> | 2023-09-21 09:29:05 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2023-09-21 09:29:05 +0200 |
commit | 8d07a43f0b364156433dc453b9d1cc762c032634 (patch) | |
tree | 55af847d92daf708fbea995a2f58cc09144cea70 | |
parent | 897fefce372f00d3e904e83eb92c83e3e653b5be (diff) |
WIP mixed environment description
-rw-r--r-- | src/CHAD.hs | 156 |
1 files changed, 104 insertions, 52 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 9a1c7d2..0c856b1 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -9,8 +9,14 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE StandaloneKindSignatures #-} module CHAD where +import Data.Bifunctor (first, second) +import Data.Kind (Type) +import GHC.TypeLits (Symbol) + import AST @@ -104,6 +110,12 @@ type family D2E env where D2E '[] = '[] D2E (t : env) = D2 t : D2E env +-- | Select only the types from the environment that have the specified storage +type family Select env sto s where + Select '[] '[] _ = '[] + Select (t : ts) (s : sto) s = t : Select ts sto s + Select (_ : ts) (_ : sto) s = Select ts sto s + d1 :: STy t -> STy (D1 t) d1 STNil = STNil d1 (STPair a b) = STPair (d1 a) (d1 b) @@ -125,21 +137,17 @@ d2 (STScal t) = case t of STBool -> STNil d2 STEVM{} = error "EVM not allowed in input program" -d2e :: SList STy list -> SList STy (D2E list) -d2e SNil = SNil -d2e (SCons t list) = SCons (d2 t) (d2e list) - -d2list :: SList STy env -> SList STy (D2E env) -d2list SNil = SNil -d2list (SCons x l) = SCons (d2 x) (d2list l) - conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) -conv2Idx :: Idx env t -> Idx (D2E env) (D2 t) -conv2Idx IZ = IZ -conv2Idx (IS i) = IS (conv2Idx i) +conv2Idx :: Descr env sto -> Idx env t -> Either (Idx (D2E (Select env sto "accum")) (D2 t)) + (Idx (D2E (Select env sto "merge")) (D2 t)) +conv2Idx (DPush _ _ SAccum) IZ = Left IZ +conv2Idx (DPush _ _ SMerge) IZ = Right IZ +conv2Idx (DPush des _ SAccum) (IS i) = first IS (conv2Idx des i) +conv2Idx (DPush des _ SMerge) (IS i) = second IS (conv2Idx des i) +conv2Idx DTop i = case i of {} zero :: STy t -> Ex env (D2 t) zero STNil = ENil ext @@ -154,23 +162,35 @@ zero (STScal t) = case t of STBool -> ENil ext zero STEVM{} = error "EVM not allowed in input program" -data Ret env t = +type family Tup env where + Tup '[] = TNil + Tup (t : ts) = TPair t (Tup ts) + +tTup :: SList STy env -> STy (Tup env) +tTup SNil = STNil +tTup (SCons t ts) = STPair t (tTup ts) + +zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) +zeroTup SNil = ENil ext +zeroTup (SCons t env) = EPair ext (zero t) (zeroTup env) + +data Ret env sto t = forall env'. Ret (Bindings Ex (D1E env) env') (Ex env' (D1 t)) - (Ex (D2 t : env') (TEVM (D2E env) TNil)) -deriving instance Show (Ret env t) + (Ex (D2 t : env') (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge"))))) +deriving instance Show (Ret env sto t) -data RetPair env0 env t = +data RetPair env0 sto env t = RetPair (Ex env (D1 t)) - (Ex (D2 t : env) (TEVM (D2E env0) TNil)) + (Ex (D2 t : env) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E (Select env0 sto "merge"))))) deriving (Show) -data Rets env0 env list = +data Rets env0 sto env list = forall env'. Rets (Bindings Ex env env') - (SList (RetPair env0 env') list) -deriving instance Show (Rets env0 env list) + (SList (RetPair env0 sto env') list) +deriving instance Show (Rets env0 sto env list) -- d1W :: env :> env' -> D1E env :> D1E env' -- d1W WId = WId @@ -179,15 +199,15 @@ deriving instance Show (Rets env0 env list) -- d1W (WPop w) = WPop (d1W w) -- d1W (WThen u w) = WThen (d1W u) (d1W w) -weakenRetPair :: env :> env' -> RetPair env0 env t -> RetPair env0 env' t +weakenRetPair :: env :> env' -> RetPair env0 sto env t -> RetPair env0 sto env' t weakenRetPair w (RetPair e1 e2) = RetPair (weakenExpr w e1) (weakenExpr (WCopy w) e2) -weakenRets :: env :> env' -> Rets env0 env list -> Rets env0 env' list +weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list weakenRets w (Rets binds list) = weakenBindings weakenExpr w binds $ \binds' wbinds' -> Rets binds' (slistMap (weakenRetPair wbinds') list) -retConcat :: forall env list. SList (Ret env) list -> Rets env (D1E env) list +retConcat :: forall env sto list. SList (Ret env sto) list -> Rets env sto (D1E env) list retConcat SNil = Rets BTop SNil retConcat (SCons (Ret (b :: Bindings Ex (D1E env) env2) p d) list) | Rets binds pairs <- weakenRets (sinkWithBindings b) (retConcat list) @@ -206,10 +226,10 @@ d1op (OEq t) e = EOp ext (OEq t) e d1op ONot e = EOp ext ONot e d1op OIf e = EOp ext OIf e +-- | Both primal and dual must be duplicable expressions data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) --- both primal and dual must be duplicable expressions d2op :: SOp a t -> D2Op a t d2op op = case op of OAdd _ -> Linear $ \d -> EInr ext STNil (EPair ext d d) @@ -243,21 +263,50 @@ d2op op = case op of STF64 -> float STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) -freezeRet :: Ret env t +freezeRet :: Ret env sto t -> (forall env'. Ex env' (D2 t)) -- the incoming cotangent value - -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E env) TNil)) + -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge"))))) freezeRet (Ret e0 e1 e2) d = letBinds e0 $ EPair ext e1 (ELet ext d e2) -drev :: SList STy env -> Ex env t -> Ret env t -drev senv = \case +type Storage :: Symbol -> Type +data Storage s where + SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator + SMerge :: Storage "merge" -- ^ just return and merge +deriving instance Show (Storage s) + +-- | Environment description +data Descr env sto where + DTop :: Descr '[] '[] + DPush :: Descr env sto -> STy t -> Storage s -> Descr (t : env) (s : sto) +deriving instance Show (Descr env sto) + +select :: Storage s -> Descr env sto -> SList STy (Select env sto s) +select _ DTop = SNil +select s@SAccum (DPush des t SAccum) = SCons t (select s des) +select s@SMerge (DPush des _ SAccum) = select s des +select s@SAccum (DPush des _ SMerge) = select s des +select s@SMerge (DPush des t SMerge) = SCons t (select s des) + +d2e :: SList STy env -> SList STy (D2E env) +d2e SNil = SNil +d2e (SCons t ts) = SCons (d2 t) (d2e ts) + +drev :: Descr env sto -> Ex env t -> Ret env sto t +drev des = \case EVar _ t i -> - Ret BTop - (EVar ext (d1 t) (conv1Idx i)) - (EMOne (d2list senv) (conv2Idx i) (EVar ext (d2 t) IZ)) + case conv2Idx des i of + Left accumI -> + Ret BTop + (EVar ext (d1 t) (conv1Idx i)) + (EMBind + (EMOne d2mon accumI (EVar ext (d2 t) IZ)) + (EMReturn d2mon (zeroTup (select SMerge des)))) + Right tupI -> + _ ELet _ rhs body - | Ret rhs0 rhs1 rhs2 <- drev senv rhs - , Ret body0 body1 body2 <- drev (SCons (typeOf rhs) senv) body -> + | Ret rhs0 rhs1 rhs2 <- drev des rhs + , Ret body0 body1 body2 <- drev (DPush des (typeOf rhs) SMerge) body -> weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 $ \body0' wbody0' -> Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') (weakenExpr wbody0' body1) @@ -267,19 +316,19 @@ drev senv = \case EPair _ a b | Rets binds (RetPair a1 a2 `SCons` RetPair b1 b2 `SCons` SNil) - <- retConcat $ drev senv a `SCons` drev senv b `SCons` SNil + <- retConcat $ drev des a `SCons` drev des b `SCons` SNil , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> Ret binds (EPair ext a1 b1) (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) - (EMReturn (d2e senv) (ENil ext)) + (EMReturn d2mon (ENil ext)) (EMBind (ELet ext (EFst ext (EVar ext dt IZ)) (weakenExpr (WCopy (wSinks @[_,_])) a2)) (ELet ext (ESnd ext (EVar ext dt (IS IZ))) (weakenExpr (WCopy (wSinks @[_,_,_])) b2)))) EFst _ e - | Ret e0 e1 e2 <- drev senv e + | Ret e0 e1 e2 <- drev des e , STPair t1 t2 <- typeOf e -> Ret e0 (EFst ext e1) @@ -287,40 +336,40 @@ drev senv = \case weakenExpr (WCopy WSink) e2) ESnd _ e - | Ret e0 e1 e2 <- drev senv e + | Ret e0 e1 e2 <- drev des e , STPair t1 t2 <- typeOf e -> Ret e0 (ESnd ext e1) (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ weakenExpr (WCopy WSink) e2) - ENil _ -> Ret BTop (ENil ext) (EMReturn (d2e senv) (ENil ext)) + ENil _ -> Ret BTop (ENil ext) (EMReturn d2mon (ENil ext)) EInl _ t2 e - | Ret e0 e1 e2 <- drev senv e -> + | Ret e0 e1 e2 <- drev des e -> Ret e0 (EInl ext (d1 t2) e1) (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) - (EMReturn (d2e senv) (ENil ext)) + (EMReturn d2mon (ENil ext)) (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) (weakenExpr (WCopy (wSinks @[_,_])) e2) - (EError (STEVM (d2e senv) STNil) "inl<-dinr"))) + (EError (STEVM d2mon STNil) "inl<-dinr"))) EInr _ t1 e - | Ret e0 e1 e2 <- drev senv e -> + | Ret e0 e1 e2 <- drev des e -> Ret e0 (EInr ext (d1 t1) e1) (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ) - (EMReturn (d2e senv) (ENil ext)) + (EMReturn d2mon (ENil ext)) (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) - (EError (STEVM (d2e senv) STNil) "inr<-dinl") + (EError (STEVM d2mon STNil) "inr<-dinl") (weakenExpr (WCopy (wSinks @[_,_])) e2))) ECase _ e a b | STEither t1 t2 <- typeOf e - , Ret e0 e1 e2 <- drev senv e - , Ret a0 a1 a2 <- drev (SCons t1 senv) a - , Ret b0 b1 b2 <- drev (SCons t2 senv) b + , Ret e0 e1 e2 <- drev des e + , Ret a0 a1 a2 <- drev (DPush des t1 SMerge) a + , Ret b0 b1 b2 <- drev (DPush des t2 SMerge) b , TupBinds tapeA collectA reconA <- tupBinds a0 , TupBinds tapeB collectB reconB <- tupBinds b0 , let tPrimal = STPair (d1 (typeOf a)) (STEither tapeA tapeB) -> @@ -341,18 +390,18 @@ drev senv = \case letBinds rebinds $ ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $ EMBind (weakenExpr (WCopy wrebinds) (EMScope a2)) - (EMReturn (d2e senv) + (EMReturn d2mon (EInr ext STNil (EInl ext (d2 t2) (ESnd ext (EVar ext (STPair STNil (d2 t1)) IZ)))))) - (EError (STEVM (d2e senv) (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape")) + (EError (STEVM d2mon (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape")) (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ)))) - (EError (STEVM (d2e senv) (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape") + (EError (STEVM d2mon (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape") (case reconB (WSink .> WCopy (wSinks @[_,_,_] .> sinkWithBindings e0)) IZ of TupBindsReconstruct rebinds wrebinds -> letBinds rebinds $ ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $ EMBind (weakenExpr (WCopy wrebinds) (EMScope b2)) - (EMReturn (d2e senv) + (EMReturn d2mon (EInr ext STNil (EInr ext (d2 t1) (ESnd ext (EVar ext (STPair STNil (d2 t2)) IZ)))))))) (weakenExpr (WCopy (wSinks @[_,_,_])) e2)) @@ -360,10 +409,10 @@ drev senv = \case EConst _ t val -> Ret BTop (EConst ext t val) - (EMReturn (d2e senv) (ENil ext)) + (EMReturn d2mon (ENil ext)) EOp _ op e - | Ret e0 e1 e2 <- drev senv e -> + | Ret e0 e1 e2 <- drev des e -> case d2op op of Linear d2opfun -> Ret e0 @@ -378,3 +427,6 @@ drev senv = \case (weakenExpr (WCopy (wSinks @[_,_])) e2)) e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e) + + where + d2mon = d2e (select SAccum des) |