diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 518 |
1 files changed, 290 insertions, 228 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 8319080..1fd34d8 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -20,9 +20,12 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} + +-- TODO DO NOT COMMIT THIS +{-# OPTIONS -Wno-unused-top-binds #-} module CHAD ( - drev, - freezeRet, + -- drev, + -- freezeRet, Storage(..), Descr(..), Select, @@ -44,6 +47,8 @@ import Data import Lemmas +------------------------------ TAPES AND BINDINGS ------------------------------ + type family Tape binds where Tape '[] = TNil Tape (t : ts) = TPair t (Tape ts) @@ -186,6 +191,10 @@ reconstructBindings binds tape = (bconcat (mapBindings fromUnfExpr unf) build) ,sreverse (stapeUnfoldings binds)) + +--------------------------------- VECTORISATION -------------------------------- +-- Currently only used in D[build1], should be removed. + type family Vectorise n list where Vectorise _ '[] = '[] Vectorise n (t : ts) = TArr n t : Vectorise n ts @@ -231,12 +240,131 @@ vectorise1Binds env n (bs `BPush` (t, e)) = (vectoriseExpr SNil (bindingsBinds bs) env e) in bs' `BPush` (STArr (SS SZ) t, e') + +---------------------- ENVIRONMENT DESCRIPTION AND STORAGE --------------------- + +type Storage :: Symbol -> Type +data Storage s where + SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator + SMerge :: Storage "merge" -- ^ just return and merge +deriving instance Show (Storage s) + +-- | Environment description +data Descr env sto where + DTop :: Descr '[] '[] + DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) +deriving instance Show (Descr env sto) + +descrList :: Descr env sto -> SList STy env +descrList DTop = SNil +descrList (des `DPush` (t, _)) = t `SCons` descrList des + +-- | This could have more precise typing on the output storage. +subDescr :: Descr env sto -> Subenv env env' + -> (forall sto'. Descr env' sto' + -> Subenv (Select env sto "merge") (Select env' sto' "merge") + -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) + -> Subenv (D1E env) (D1E env') + -> r) + -> r +subDescr DTop SETop k = k DTop SETop SETop SETop +subDescr (des `DPush` (t, sto)) (SEYes sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) + SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) +subDescr (des `DPush` (_, sto)) (SENo sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) + SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) + -- | Select only the types from the environment that have the specified storage type family Select env sto s where Select '[] '[] _ = '[] Select (t : ts) (s : sto) s = t : Select ts sto s Select (_ : ts) (_ : sto) s = Select ts sto s +select :: Storage s -> Descr env sto -> SList STy (Select env sto s) +select _ DTop = SNil +select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) +select s@SMerge (DPush des (_, SAccum)) = select s des +select s@SAccum (DPush des (_, SMerge)) = select s des +select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) + + +---------------------------------- DERIVATIVES --------------------------------- + +d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) +d1op (OAdd t) e = EOp ext (OAdd t) e +d1op (OMul t) e = EOp ext (OMul t) e +d1op (ONeg t) e = EOp ext (ONeg t) e +d1op (OLt t) e = EOp ext (OLt t) e +d1op (OLe t) e = EOp ext (OLe t) e +d1op (OEq t) e = EOp ext (OEq t) e +d1op ONot e = EOp ext ONot e +d1op OAnd e = EOp ext OAnd e +d1op OOr e = EOp ext OOr e +d1op OIf e = EOp ext OIf e +d1op ORound64 e = EOp ext ORound64 e +d1op OToFl64 e = EOp ext OToFl64 e + +-- | Both primal and dual must be duplicable expressions +data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) + | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) + +d2op :: SOp a t -> D2Op a t +d2op op = case op of + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d) + OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> + EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d))) + ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d + OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + ONot -> Linear $ \_ -> ENil ext + OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + OIf -> Linear $ \_ -> ENil ext + ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 + OToFl64 -> Linear $ \_ -> ENil ext + where + d2opUnArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => D2Op (TScal a) t) + -> D2Op (TScal a) t + d2opUnArrangeInt ty float = case ty of + STI32 -> Linear $ \_ -> ENil ext + STI64 -> Linear $ \_ -> ENil ext + STF32 -> float + STF64 -> float + STBool -> Linear $ \_ -> ENil ext + + d2opBinArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) + -> D2Op (TPair (TScal a) (TScal a)) t + d2opBinArrangeInt ty float = case ty of + STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + STF32 -> float + STF64 -> float + STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + +sD1eEnv :: Descr env sto -> SList STy (D1E env) +sD1eEnv DTop = SNil +sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) + +d2ace :: SList STy env -> SList STy (D2AcE env) +d2ace SNil = SNil +d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts) + +-- d1W :: env :> env' -> D1E env :> D1E env' +-- d1W WId = WId +-- d1W WSink = WSink +-- d1W (WCopy w) = WCopy (d1W w) +-- d1W (WPop w) = WPop (d1W w) +-- d1W (WThen u w) = WThen (d1W u) (d1W w) + conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) @@ -250,6 +378,16 @@ conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i) conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i) conv2Idx DTop i = case i of {} + +------------------------------------ LEMMAS ------------------------------------ + +indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) +indexTupD1Id SZ = Refl +indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + + +------------------------------------ MONOIDS ----------------------------------- + zero :: STy t -> Ex env (D2 t) zero = EZero -- TODO: this original definition needs to be used as the post-processing after @@ -312,9 +450,77 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) -indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) -indexTupD1Id SZ = Refl -indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +------------------------------------ SUBENVS ----------------------------------- + +subenvPlus :: SList STy env + -> Subenv env env1 -> Subenv env env2 + -> (forall env3. Subenv env env3 + -> Subenv env3 env1 + -> Subenv env3 env2 + -> (Ex exenv (Tup (D2E env1)) + -> Ex exenv (Tup (D2E env2)) + -> Ex exenv (Tup (D2E env3))) + -> r) + -> r +subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) +subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = + subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> + k (SENo sub3) s31 s32 pl +subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = + subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> + k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ)) +subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = + subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> + k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> + ELet ext e2 $ + EPair ext (pl (weakenExpr WSink e1) + (EFst ext (EVar ext (typeOf e2) IZ))) + (ESnd ext (EVar ext (typeOf e2) IZ)) +subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = + subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> + k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> + ELet ext e1 $ + ELet ext (weakenExpr WSink e2) $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) + (EFst ext (EVar ext (typeOf e2) IZ))) + (plus t + (ESnd ext (EVar ext (typeOf e1) (IS IZ))) + (ESnd ext (EVar ext (typeOf e2) IZ))) + +expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SETop _ = ENil ext +expandSubenvZeros (SCons t ts) (SEYes sub) e = + ELet ext e $ + let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ + in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) +expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) + +assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] +assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl +assertSubenvEmpty SETop = Refl +assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" + +popFromScope + :: Descr env0 sto + -> STy a + -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub + -> Ex env (Tup (D2E envSub)) + -> (forall envSub'. + Subenv (Select env0 sto "merge") envSub' + -> Ex env (TPair (Tup (D2E envSub')) (D2 a)) + -> r) + -> r +popFromScope _ ty sub e k = case sub of + SEYes sub' -> k sub' e + SENo sub' -> k sub' $ EPair ext e (zero ty) + + +--------------------------------- ACCUMULATORS --------------------------------- accumPromote :: forall dt env sto proxy r. proxy dt @@ -411,252 +617,98 @@ uninvertTup (t `SCons` list) tcore e = (ESnd ext (EVar ext recT IZ)) (ESnd ext (EFst ext (EVar ext recT IZ)))) + +---------------------------- RETURN TRIPLE FROM CHAD --------------------------- + data Ret env0 sto t = - forall shbinds env0Merge. + forall shbinds tapebinds env0Merge. Ret (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) (Ex (Append shbinds (D1E env0)) (D1 t)) (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) + (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) deriving instance Show (Ret env0 sto t) -data RetPair env0 sto env shbinds t = +data RetPair env0 sto env shbinds tapebinds t = forall env0Merge. RetPair (Ex (Append shbinds env) (D1 t)) (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (RetPair env0 sto env shbinds t) + (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) +deriving instance Show (RetPair env0 sto env shbinds tapebinds t) data Rets env0 sto env list = - forall shbinds. + forall shbinds tapebinds. Rets (Bindings Ex env shbinds) - (SList (RetPair env0 sto env shbinds) list) + (Subenv shbinds tapebinds) + (SList (RetPair env0 sto env shbinds tapebinds) list) deriving instance Show (Rets env0 sto env list) -subenvPlus :: SList STy env - -> Subenv env env1 -> Subenv env env2 - -> (forall env3. Subenv env env3 - -> Subenv env3 env1 - -> Subenv env3 env2 - -> (Ex exenv (Tup (D2E env1)) - -> Ex exenv (Tup (D2E env2)) - -> Ex exenv (Tup (D2E env3))) - -> r) - -> r -subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) -subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> - ELet ext e2 $ - EPair ext (pl (weakenExpr WSink e1) - (EFst ext (EVar ext (typeOf e2) IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> - ELet ext e1 $ - ELet ext (weakenExpr WSink e2) $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) - (EFst ext (EVar ext (typeOf e2) IZ))) - (plus t - (ESnd ext (EVar ext (typeOf e1) (IS IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ))) - -expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SETop _ = ENil ext -expandSubenvZeros (SCons t ts) (SEYes sub) e = - ELet ext e $ - let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ - in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) - -assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] -assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl -assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" - -popFromScope - :: Descr env0 sto - -> STy a - -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub - -> Ex env (Tup (D2E envSub)) - -> (forall envSub'. - Subenv (Select env0 sto "merge") envSub' - -> Ex env (TPair (Tup (D2E envSub')) (D2 a)) - -> r) - -> r -popFromScope _ ty sub e k = case sub of - SEYes sub' -> k sub' e - SENo sub' -> k sub' $ EPair ext e (zero ty) - --- d1W :: env :> env' -> D1E env :> D1E env' --- d1W WId = WId --- d1W WSink = WSink --- d1W (WCopy w) = WCopy (d1W w) --- d1W (WPop w) = WPop (d1W w) --- d1W (WThen u w) = WThen (d1W u) (d1W w) - -weakenRetPair :: SList STy shbinds -> env :> env' -> RetPair env0 sto env shbinds t -> RetPair env0 sto env' shbinds t +weakenRetPair :: SList STy shbinds -> env :> env' + -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list -weakenRets w (Rets binds list) = +weakenRets w (Rets binds tapesub list) = let (binds', _) = weakenBindings weakenExpr w binds - in Rets binds' (slistMap (weakenRetPair (bindingsBinds binds) w) list) - -rebaseRetPair :: forall env b1 b2 env0 sto t f. - Descr env0 sto -> SList f b1 -> SList f b2 - -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t -rebaseRetPair descr b1 b2 (RetPair p sub d) + in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) + +rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f. + Descr env0 sto + -> SList f b1 -> SList f b2 + -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 + -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t + -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t +rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d) | Refl <- lemAppendAssoc @b2 @b1 @env = RetPair p sub (weakenExpr (autoWeak - (#d (auto1 @(D2 t)) &. #b2 b2 &. #b1 b1 &. #tl (d2ace (select SAccum descr))) - (#d :++: (#b2 :++: #tl)) - (#d :++: ((#b2 :++: #b1) :++: #tl))) + (#d (auto1 @(D2 t)) + &. #t2 (subList b2 subtape2) + &. #t1 (subList b1 subtape1) + &. #tl (d2ace (select SAccum descr))) + (#d :++: (#t2 :++: #tl)) + (#d :++: ((#t2 :++: #t1) :++: #tl))) d) retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list -retConcat _ SNil = Rets BTop SNil -retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list) - | Rets binds1 pairs1 <- retConcat descr list - , Rets (binds :: Bindings _ _ shbinds2) pairs <- weakenRets (sinkWithBindings b) (Rets binds1 pairs1) - , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D1E env0) - , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D2AcE (Select env0 sto "accum")) +retConcat _ SNil = Rets BTop SETop SNil +retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list) + | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs + <- weakenRets (sinkWithBindings b) (retConcat descr list) + , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) + , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) = Rets (bconcat b binds) + (subenvConcat subtape subtape2) (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) sub - (weakenExpr (WCopy (sinkWithBindings binds)) d)) - (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)) pairs)) - -d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) -d1op (OAdd t) e = EOp ext (OAdd t) e -d1op (OMul t) e = EOp ext (OMul t) e -d1op (ONeg t) e = EOp ext (ONeg t) e -d1op (OLt t) e = EOp ext (OLt t) e -d1op (OLe t) e = EOp ext (OLe t) e -d1op (OEq t) e = EOp ext (OEq t) e -d1op ONot e = EOp ext ONot e -d1op OAnd e = EOp ext OAnd e -d1op OOr e = EOp ext OOr e -d1op OIf e = EOp ext OIf e -d1op ORound64 e = EOp ext ORound64 e -d1op OToFl64 e = EOp ext OToFl64 e - --- | Both primal and dual must be duplicable expressions -data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) - | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) - -d2op :: SOp a t -> D2Op a t -d2op op = case op of - OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d) - OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> - EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) - (EOp ext (OMul t) (EPair ext (EFst ext e) d))) - ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d - OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) - OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) - OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) - ONot -> Linear $ \_ -> ENil ext - OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) - OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) - OIf -> Linear $ \_ -> ENil ext - ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 - OToFl64 -> Linear $ \_ -> ENil ext - where - d2opUnArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => D2Op (TScal a) t) - -> D2Op (TScal a) t - d2opUnArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENil ext - STI64 -> Linear $ \_ -> ENil ext - STF32 -> float - STF64 -> float - STBool -> Linear $ \_ -> ENil ext - - d2opBinArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) - -> D2Op (TPair (TScal a) (TScal a)) t - d2opBinArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) - STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) - STF32 -> float - STF64 -> float - STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) - -type Storage :: Symbol -> Type -data Storage s where - SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator - SMerge :: Storage "merge" -- ^ just return and merge -deriving instance Show (Storage s) - --- | Environment description -data Descr env sto where - DTop :: Descr '[] '[] - DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) -deriving instance Show (Descr env sto) - -descrList :: Descr env sto -> SList STy env -descrList DTop = SNil -descrList (des `DPush` (t, _)) = t `SCons` descrList des - -select :: Storage s -> Descr env sto -> SList STy (Select env sto s) -select _ DTop = SNil -select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) -select s@SMerge (DPush des (_, SAccum)) = select s des -select s@SAccum (DPush des (_, SMerge)) = select s des -select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) - --- | This could have more precise typing on the output storage. -subDescr :: Descr env sto -> Subenv env env' - -> (forall sto'. Descr env' sto' - -> Subenv (Select env sto "merge") (Select env' sto' "merge") - -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) - -> Subenv (D1E env) (D1E env') - -> r) - -> r -subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, sto)) (SEYes sub) k = - subDescr des sub $ \des' submerge subaccum subd1e -> - case sto of - SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) - SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) -subDescr (des `DPush` (_, sto)) (SENo sub) k = - subDescr des sub $ \des' submerge subaccum subd1e -> - case sto of - SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) - SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) - -sD1eEnv :: Descr env sto -> SList STy (D1E env) -sD1eEnv DTop = SNil -sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) - -d2ace :: SList STy env -> SList STy (D2AcE env) -d2ace SNil = SNil -d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts) + (weakenExpr (WCopy (sinkWithSubenv subtape2)) d)) + (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds) + subtape subtape2) + pairs)) freezeRet :: Descr env sto -> Ret env sto t -> Ex (D1E env) (D2 t) -- the incoming cotangent value -> Ex (Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) -freezeRet descr (Ret e0 e1 sub e2) d = +freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) d = let (e0', wInsertD2Ac) = weakenBindings weakenExpr (wSinks (d2ace (select SAccum descr))) e0 - e2' = weakenExpr (WCopy (wCopies (bindingsBinds e0) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2 + e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2 in letBinds e0' $ EPair ext (weakenExpr wInsertD2Ac e1) (ELet ext (weakenExpr (sinkWithBindings e0 .> wSinks (d2ace (select SAccum descr))) d) $ - ELet ext e2' $ + ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #tape (subList (bindingsBinds e0) subtape) + &. #shbinds (bindingsBinds e0) + &. #d2ace (d2ace (select SAccum descr)) + &. #tl (sD1eEnv descr)) + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) + (#d :++: #shbinds :++: #d2ace :++: #tl)) + e2') $ expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) + +---------------------------- THE CHAD TRANSFORMATION --------------------------- + drev :: forall env sto t. Descr env sto -> Ex env t -> Ret env sto t @@ -665,50 +717,55 @@ drev des = \case case conv2Idx des i of Left accI -> Ret BTop + SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI))) Right tupI -> Ret BTop + SETop (EVar ext (d1 t) (conv1Idx i)) (subenvOnehot (select SMerge des) tupI) (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) ELet _ rhs body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs - , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, SMerge)) body + | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 + <- drev des rhs + , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 + <- drev (des `DPush` (typeOf rhs, SMerge)) body , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) - , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D2AcE (Select env sto "accum")) - , Refl <- lemAppendNil @body_shbinds -> + , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> popFromScope des (typeOf rhs) subBody body2 $ \subBody' body2' -> subenvPlus (select SMerge des) subRHS subBody' $ \subBoth _ _ plus_RHS_Body -> let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') + (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) (weakenExpr wbody0' body1) subBoth (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (bindingsBinds body0) - &. #rhs (SCons (typeOf rhs1) (bindingsBinds rhs0)) + &. #body (subList (bindingsBinds body0) subtapeBody) + &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #tl) - (#d :++: #body :++: #rhs :++: #tl)) + (#d :++: (#body :++: #rhs) :++: #tl)) body2') $ ELet ext (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2) $ + weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ plus_RHS_Body (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) (EFst ext (EVar ext bodyResType (IS IZ)))) EPair _ a b - | Rets binds (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) + | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> Ret binds + subtape (EPair ext a1 b1) subBoth (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) @@ -722,28 +779,31 @@ drev des = \case (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))) EFst _ e - | Ret e0 e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des e , STPair t1 t2 <- typeOf e -> Ret e0 + subtape (EFst ext e1) sub (ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $ weakenExpr (WCopy WSink) e2) ESnd _ e - | Ret e0 e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des e , STPair t1 t2 <- typeOf e -> Ret e0 + subtape (ESnd ext e1) sub (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ weakenExpr (WCopy WSink) e2) - ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) + ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) EInl _ t2 e - | Ret e0 e1 sub e2 <- drev des e -> + | Ret e0 subtape e1 sub e2 <- drev des e -> Ret e0 + subtape (EInl ext (d1 t2) e1) sub (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) @@ -753,8 +813,9 @@ drev des = \case (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))) EInr _ t1 e - | Ret e0 e1 sub e2 <- drev des e -> + | Ret e0 subtape e1 sub e2 <- drev des e -> Ret e0 + subtape (EInr ext (d1 t1) e1) sub (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ) @@ -763,6 +824,7 @@ drev des = \case (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") (weakenExpr (WCopy (wSinks' @[_,_])) e2))) +{- ECase _ e (a :: Ex _ t) b | STEither t1 t2 <- typeOf e , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des e @@ -962,7 +1024,6 @@ drev des = \case case assertSubenvEmpty sub of { Refl -> let tapety = tapeTy (bindingsBinds e0) in let collectexpr = bindingsCollect e0 in - -- let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in Ret (she0 `BPush` (shty, she1) `BPush` (STArr ndim tapety ,EBuild ext ndim @@ -1129,6 +1190,7 @@ drev des = \case EAccum{} -> err_accum EZero{} -> err_monoid EPlus{} -> err_monoid +-} where err_accum = error "Accumulator operations unsupported in the source program" |