summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs518
1 files changed, 290 insertions, 228 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 8319080..1fd34d8 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -20,9 +20,12 @@
-- useful here.
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS -Wno-partial-type-signatures #-}
+
+-- TODO DO NOT COMMIT THIS
+{-# OPTIONS -Wno-unused-top-binds #-}
module CHAD (
- drev,
- freezeRet,
+ -- drev,
+ -- freezeRet,
Storage(..),
Descr(..),
Select,
@@ -44,6 +47,8 @@ import Data
import Lemmas
+------------------------------ TAPES AND BINDINGS ------------------------------
+
type family Tape binds where
Tape '[] = TNil
Tape (t : ts) = TPair t (Tape ts)
@@ -186,6 +191,10 @@ reconstructBindings binds tape =
(bconcat (mapBindings fromUnfExpr unf) build)
,sreverse (stapeUnfoldings binds))
+
+--------------------------------- VECTORISATION --------------------------------
+-- Currently only used in D[build1], should be removed.
+
type family Vectorise n list where
Vectorise _ '[] = '[]
Vectorise n (t : ts) = TArr n t : Vectorise n ts
@@ -231,12 +240,131 @@ vectorise1Binds env n (bs `BPush` (t, e)) =
(vectoriseExpr SNil (bindingsBinds bs) env e)
in bs' `BPush` (STArr (SS SZ) t, e')
+
+---------------------- ENVIRONMENT DESCRIPTION AND STORAGE ---------------------
+
+type Storage :: Symbol -> Type
+data Storage s where
+ SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator
+ SMerge :: Storage "merge" -- ^ just return and merge
+deriving instance Show (Storage s)
+
+-- | Environment description
+data Descr env sto where
+ DTop :: Descr '[] '[]
+ DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
+deriving instance Show (Descr env sto)
+
+descrList :: Descr env sto -> SList STy env
+descrList DTop = SNil
+descrList (des `DPush` (t, _)) = t `SCons` descrList des
+
+-- | This could have more precise typing on the output storage.
+subDescr :: Descr env sto -> Subenv env env'
+ -> (forall sto'. Descr env' sto'
+ -> Subenv (Select env sto "merge") (Select env' sto' "merge")
+ -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
+ -> Subenv (D1E env) (D1E env')
+ -> r)
+ -> r
+subDescr DTop SETop k = k DTop SETop SETop SETop
+subDescr (des `DPush` (t, sto)) (SEYes sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
+ SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
+subDescr (des `DPush` (_, sto)) (SENo sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
+ SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
+
-- | Select only the types from the environment that have the specified storage
type family Select env sto s where
Select '[] '[] _ = '[]
Select (t : ts) (s : sto) s = t : Select ts sto s
Select (_ : ts) (_ : sto) s = Select ts sto s
+select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
+select _ DTop = SNil
+select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
+select s@SMerge (DPush des (_, SAccum)) = select s des
+select s@SAccum (DPush des (_, SMerge)) = select s des
+select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
+
+
+---------------------------------- DERIVATIVES ---------------------------------
+
+d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
+d1op (OAdd t) e = EOp ext (OAdd t) e
+d1op (OMul t) e = EOp ext (OMul t) e
+d1op (ONeg t) e = EOp ext (ONeg t) e
+d1op (OLt t) e = EOp ext (OLt t) e
+d1op (OLe t) e = EOp ext (OLe t) e
+d1op (OEq t) e = EOp ext (OEq t) e
+d1op ONot e = EOp ext ONot e
+d1op OAnd e = EOp ext OAnd e
+d1op OOr e = EOp ext OOr e
+d1op OIf e = EOp ext OIf e
+d1op ORound64 e = EOp ext ORound64 e
+d1op OToFl64 e = EOp ext OToFl64 e
+
+-- | Both primal and dual must be duplicable expressions
+data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
+ | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))
+
+d2op :: SOp a t -> D2Op a t
+d2op op = case op of
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d)
+ OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
+ EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
+ ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
+ OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ ONot -> Linear $ \_ -> ENil ext
+ OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ OIf -> Linear $ \_ -> ENil ext
+ ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
+ OToFl64 -> Linear $ \_ -> ENil ext
+ where
+ d2opUnArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => D2Op (TScal a) t)
+ -> D2Op (TScal a) t
+ d2opUnArrangeInt ty float = case ty of
+ STI32 -> Linear $ \_ -> ENil ext
+ STI64 -> Linear $ \_ -> ENil ext
+ STF32 -> float
+ STF64 -> float
+ STBool -> Linear $ \_ -> ENil ext
+
+ d2opBinArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
+ -> D2Op (TPair (TScal a) (TScal a)) t
+ d2opBinArrangeInt ty float = case ty of
+ STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ STF32 -> float
+ STF64 -> float
+ STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+
+sD1eEnv :: Descr env sto -> SList STy (D1E env)
+sD1eEnv DTop = SNil
+sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
+
+d2ace :: SList STy env -> SList STy (D2AcE env)
+d2ace SNil = SNil
+d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts)
+
+-- d1W :: env :> env' -> D1E env :> D1E env'
+-- d1W WId = WId
+-- d1W WSink = WSink
+-- d1W (WCopy w) = WCopy (d1W w)
+-- d1W (WPop w) = WPop (d1W w)
+-- d1W (WThen u w) = WThen (d1W u) (d1W w)
+
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)
@@ -250,6 +378,16 @@ conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i)
conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i)
conv2Idx DTop i = case i of {}
+
+------------------------------------ LEMMAS ------------------------------------
+
+indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
+indexTupD1Id SZ = Refl
+indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
+
+------------------------------------ MONOIDS -----------------------------------
+
zero :: STy t -> Ex env (D2 t)
zero = EZero
-- TODO: this original definition needs to be used as the post-processing after
@@ -312,9 +450,77 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext
zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
-indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
-indexTupD1Id SZ = Refl
-indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
+------------------------------------ SUBENVS -----------------------------------
+
+subenvPlus :: SList STy env
+ -> Subenv env env1 -> Subenv env env2
+ -> (forall env3. Subenv env env3
+ -> Subenv env3 env1
+ -> Subenv env3 env2
+ -> (Ex exenv (Tup (D2E env1))
+ -> Ex exenv (Tup (D2E env2))
+ -> Ex exenv (Tup (D2E env3)))
+ -> r)
+ -> r
+subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
+subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
+ subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SENo sub3) s31 s32 pl
+subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k =
+ subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ))
+subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k =
+ subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 ->
+ ELet ext e2 $
+ EPair ext (pl (weakenExpr WSink e1)
+ (EFst ext (EVar ext (typeOf e2) IZ)))
+ (ESnd ext (EVar ext (typeOf e2) IZ))
+subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
+ subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 ->
+ ELet ext e1 $
+ ELet ext (weakenExpr WSink e2) $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
+ (EFst ext (EVar ext (typeOf e2) IZ)))
+ (plus t
+ (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
+ (ESnd ext (EVar ext (typeOf e2) IZ)))
+
+expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0))
+expandSubenvZeros _ SETop _ = ENil ext
+expandSubenvZeros (SCons t ts) (SEYes sub) e =
+ ELet ext e $
+ let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ
+ in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
+expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t)
+
+assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
+assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
+assertSubenvEmpty SETop = Refl
+assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
+
+popFromScope
+ :: Descr env0 sto
+ -> STy a
+ -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub
+ -> Ex env (Tup (D2E envSub))
+ -> (forall envSub'.
+ Subenv (Select env0 sto "merge") envSub'
+ -> Ex env (TPair (Tup (D2E envSub')) (D2 a))
+ -> r)
+ -> r
+popFromScope _ ty sub e k = case sub of
+ SEYes sub' -> k sub' e
+ SENo sub' -> k sub' $ EPair ext e (zero ty)
+
+
+--------------------------------- ACCUMULATORS ---------------------------------
accumPromote :: forall dt env sto proxy r.
proxy dt
@@ -411,252 +617,98 @@ uninvertTup (t `SCons` list) tcore e =
(ESnd ext (EVar ext recT IZ))
(ESnd ext (EFst ext (EVar ext recT IZ))))
+
+---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
+
data Ret env0 sto t =
- forall shbinds env0Merge.
+ forall shbinds tapebinds env0Merge.
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
+ (Subenv shbinds tapebinds)
(Ex (Append shbinds (D1E env0)) (D1 t))
(Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
+ (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
deriving instance Show (Ret env0 sto t)
-data RetPair env0 sto env shbinds t =
+data RetPair env0 sto env shbinds tapebinds t =
forall env0Merge.
RetPair (Ex (Append shbinds env) (D1 t))
(Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
-deriving instance Show (RetPair env0 sto env shbinds t)
+ (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
+deriving instance Show (RetPair env0 sto env shbinds tapebinds t)
data Rets env0 sto env list =
- forall shbinds.
+ forall shbinds tapebinds.
Rets (Bindings Ex env shbinds)
- (SList (RetPair env0 sto env shbinds) list)
+ (Subenv shbinds tapebinds)
+ (SList (RetPair env0 sto env shbinds tapebinds) list)
deriving instance Show (Rets env0 sto env list)
-subenvPlus :: SList STy env
- -> Subenv env env1 -> Subenv env env2
- -> (forall env3. Subenv env env3
- -> Subenv env3 env1
- -> Subenv env3 env2
- -> (Ex exenv (Tup (D2E env1))
- -> Ex exenv (Tup (D2E env2))
- -> Ex exenv (Tup (D2E env3)))
- -> r)
- -> r
-subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
-subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SENo sub3) s31 s32 pl
-subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 ->
- ELet ext e1 $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
- (weakenExpr WSink e2))
- (ESnd ext (EVar ext (typeOf e1) IZ))
-subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 ->
- ELet ext e2 $
- EPair ext (pl (weakenExpr WSink e1)
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ))
-subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 ->
- ELet ext e1 $
- ELet ext (weakenExpr WSink e2) $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (plus t
- (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ)))
-
-expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0))
-expandSubenvZeros _ SETop _ = ENil ext
-expandSubenvZeros (SCons t ts) (SEYes sub) e =
- ELet ext e $
- let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ
- in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
-expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t)
-
-assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
-assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
-assertSubenvEmpty SETop = Refl
-assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
-
-popFromScope
- :: Descr env0 sto
- -> STy a
- -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub
- -> Ex env (Tup (D2E envSub))
- -> (forall envSub'.
- Subenv (Select env0 sto "merge") envSub'
- -> Ex env (TPair (Tup (D2E envSub')) (D2 a))
- -> r)
- -> r
-popFromScope _ ty sub e k = case sub of
- SEYes sub' -> k sub' e
- SENo sub' -> k sub' $ EPair ext e (zero ty)
-
--- d1W :: env :> env' -> D1E env :> D1E env'
--- d1W WId = WId
--- d1W WSink = WSink
--- d1W (WCopy w) = WCopy (d1W w)
--- d1W (WPop w) = WPop (d1W w)
--- d1W (WThen u w) = WThen (d1W u) (d1W w)
-
-weakenRetPair :: SList STy shbinds -> env :> env' -> RetPair env0 sto env shbinds t -> RetPair env0 sto env' shbinds t
+weakenRetPair :: SList STy shbinds -> env :> env'
+ -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t
weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2
weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
-weakenRets w (Rets binds list) =
+weakenRets w (Rets binds tapesub list) =
let (binds', _) = weakenBindings weakenExpr w binds
- in Rets binds' (slistMap (weakenRetPair (bindingsBinds binds) w) list)
-
-rebaseRetPair :: forall env b1 b2 env0 sto t f.
- Descr env0 sto -> SList f b1 -> SList f b2
- -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t
-rebaseRetPair descr b1 b2 (RetPair p sub d)
+ in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list)
+
+rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f.
+ Descr env0 sto
+ -> SList f b1 -> SList f b2
+ -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2
+ -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t
+ -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t
+rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d)
| Refl <- lemAppendAssoc @b2 @b1 @env =
RetPair p sub (weakenExpr (autoWeak
- (#d (auto1 @(D2 t)) &. #b2 b2 &. #b1 b1 &. #tl (d2ace (select SAccum descr)))
- (#d :++: (#b2 :++: #tl))
- (#d :++: ((#b2 :++: #b1) :++: #tl)))
+ (#d (auto1 @(D2 t))
+ &. #t2 (subList b2 subtape2)
+ &. #t1 (subList b1 subtape1)
+ &. #tl (d2ace (select SAccum descr)))
+ (#d :++: (#t2 :++: #tl))
+ (#d :++: ((#t2 :++: #t1) :++: #tl)))
d)
retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
-retConcat _ SNil = Rets BTop SNil
-retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list)
- | Rets binds1 pairs1 <- retConcat descr list
- , Rets (binds :: Bindings _ _ shbinds2) pairs <- weakenRets (sinkWithBindings b) (Rets binds1 pairs1)
- , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D1E env0)
- , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D2AcE (Select env0 sto "accum"))
+retConcat _ SNil = Rets BTop SETop SNil
+retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list)
+ | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs
+ <- weakenRets (sinkWithBindings b) (retConcat descr list)
+ , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0)
+ , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum"))
= Rets (bconcat b binds)
+ (subenvConcat subtape subtape2)
(SCons (RetPair (weakenExpr (sinkWithBindings binds) p)
sub
- (weakenExpr (WCopy (sinkWithBindings binds)) d))
- (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)) pairs))
-
-d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
-d1op (OAdd t) e = EOp ext (OAdd t) e
-d1op (OMul t) e = EOp ext (OMul t) e
-d1op (ONeg t) e = EOp ext (ONeg t) e
-d1op (OLt t) e = EOp ext (OLt t) e
-d1op (OLe t) e = EOp ext (OLe t) e
-d1op (OEq t) e = EOp ext (OEq t) e
-d1op ONot e = EOp ext ONot e
-d1op OAnd e = EOp ext OAnd e
-d1op OOr e = EOp ext OOr e
-d1op OIf e = EOp ext OIf e
-d1op ORound64 e = EOp ext ORound64 e
-d1op OToFl64 e = EOp ext OToFl64 e
-
--- | Both primal and dual must be duplicable expressions
-data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
- | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))
-
-d2op :: SOp a t -> D2Op a t
-d2op op = case op of
- OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d)
- OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
- EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
- (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
- ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
- OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- ONot -> Linear $ \_ -> ENil ext
- OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- OIf -> Linear $ \_ -> ENil ext
- ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
- OToFl64 -> Linear $ \_ -> ENil ext
- where
- d2opUnArrangeInt :: SScalTy a
- -> (D2s a ~ TScal a => D2Op (TScal a) t)
- -> D2Op (TScal a) t
- d2opUnArrangeInt ty float = case ty of
- STI32 -> Linear $ \_ -> ENil ext
- STI64 -> Linear $ \_ -> ENil ext
- STF32 -> float
- STF64 -> float
- STBool -> Linear $ \_ -> ENil ext
-
- d2opBinArrangeInt :: SScalTy a
- -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
- -> D2Op (TPair (TScal a) (TScal a)) t
- d2opBinArrangeInt ty float = case ty of
- STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- STF32 -> float
- STF64 -> float
- STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
-
-type Storage :: Symbol -> Type
-data Storage s where
- SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator
- SMerge :: Storage "merge" -- ^ just return and merge
-deriving instance Show (Storage s)
-
--- | Environment description
-data Descr env sto where
- DTop :: Descr '[] '[]
- DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
-deriving instance Show (Descr env sto)
-
-descrList :: Descr env sto -> SList STy env
-descrList DTop = SNil
-descrList (des `DPush` (t, _)) = t `SCons` descrList des
-
-select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
-select _ DTop = SNil
-select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
-select s@SMerge (DPush des (_, SAccum)) = select s des
-select s@SAccum (DPush des (_, SMerge)) = select s des
-select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
-
--- | This could have more precise typing on the output storage.
-subDescr :: Descr env sto -> Subenv env env'
- -> (forall sto'. Descr env' sto'
- -> Subenv (Select env sto "merge") (Select env' sto' "merge")
- -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
- -> Subenv (D1E env) (D1E env')
- -> r)
- -> r
-subDescr DTop SETop k = k DTop SETop SETop SETop
-subDescr (des `DPush` (t, sto)) (SEYes sub) k =
- subDescr des sub $ \des' submerge subaccum subd1e ->
- case sto of
- SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
- SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
-subDescr (des `DPush` (_, sto)) (SENo sub) k =
- subDescr des sub $ \des' submerge subaccum subd1e ->
- case sto of
- SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
- SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
-
-sD1eEnv :: Descr env sto -> SList STy (D1E env)
-sD1eEnv DTop = SNil
-sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
-
-d2ace :: SList STy env -> SList STy (D2AcE env)
-d2ace SNil = SNil
-d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts)
+ (weakenExpr (WCopy (sinkWithSubenv subtape2)) d))
+ (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)
+ subtape subtape2)
+ pairs))
freezeRet :: Descr env sto
-> Ret env sto t
-> Ex (D1E env) (D2 t) -- the incoming cotangent value
-> Ex (Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
-freezeRet descr (Ret e0 e1 sub e2) d =
+freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) d =
let (e0', wInsertD2Ac) = weakenBindings weakenExpr (wSinks (d2ace (select SAccum descr))) e0
- e2' = weakenExpr (WCopy (wCopies (bindingsBinds e0) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2
+ e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2
in letBinds e0' $
EPair ext
(weakenExpr wInsertD2Ac e1)
(ELet ext (weakenExpr (sinkWithBindings e0 .> wSinks (d2ace (select SAccum descr))) d) $
- ELet ext e2' $
+ ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #shbinds (bindingsBinds e0)
+ &. #d2ace (d2ace (select SAccum descr))
+ &. #tl (sD1eEnv descr))
+ (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
+ (#d :++: #shbinds :++: #d2ace :++: #tl))
+ e2') $
expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
+
+---------------------------- THE CHAD TRANSFORMATION ---------------------------
+
drev :: forall env sto t.
Descr env sto
-> Ex env t -> Ret env sto t
@@ -665,50 +717,55 @@ drev des = \case
case conv2Idx des i of
Left accI ->
Ret BTop
+ SETop
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (select SMerge des))
(EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI)))
Right tupI ->
Ret BTop
+ SETop
(EVar ext (d1 t) (conv1Idx i))
(subenvOnehot (select SMerge des) tupI)
(EPair ext (ENil ext) (EVar ext (d2 t) IZ))
ELet _ rhs body
- | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs
- , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, SMerge)) body
+ | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2
+ <- drev des rhs
+ , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2
+ <- drev (des `DPush` (typeOf rhs, SMerge)) body
, let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
, Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
- , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D2AcE (Select env sto "accum"))
- , Refl <- lemAppendNil @body_shbinds ->
+ , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
popFromScope des (typeOf rhs) subBody body2 $ \subBody' body2' ->
subenvPlus (select SMerge des) subRHS subBody' $ \subBoth _ _ plus_RHS_Body ->
let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
+ (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
(weakenExpr wbody0' body1)
subBoth
(ELet ext
(weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (bindingsBinds body0)
- &. #rhs (SCons (typeOf rhs1) (bindingsBinds rhs0))
+ &. #body (subList (bindingsBinds body0) subtapeBody)
+ &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
&. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #tl)
- (#d :++: #body :++: #rhs :++: #tl))
+ (#d :++: (#body :++: #rhs) :++: #tl))
body2') $
ELet ext
(ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
- weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2) $
+ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
plus_RHS_Body
(EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
(EFst ext (EVar ext bodyResType (IS IZ))))
EPair _ a b
- | Rets binds (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
+ | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
<- retConcat des $ drev des a `SCons` drev des b `SCons` SNil
, let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->
Ret binds
+ subtape
(EPair ext a1 b1)
subBoth
(ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)
@@ -722,28 +779,31 @@ drev des = \case
(EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)))
EFst _ e
- | Ret e0 e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des e
, STPair t1 t2 <- typeOf e ->
Ret e0
+ subtape
(EFst ext e1)
sub
(ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des e
, STPair t1 t2 <- typeOf e ->
Ret e0
+ subtape
(ESnd ext e1)
sub
(ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
weakenExpr (WCopy WSink) e2)
- ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
+ ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
EInl _ t2 e
- | Ret e0 e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des e ->
Ret e0
+ subtape
(EInl ext (d1 t2) e1)
sub
(ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ)
@@ -753,8 +813,9 @@ drev des = \case
(EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")))
EInr _ t1 e
- | Ret e0 e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des e ->
Ret e0
+ subtape
(EInr ext (d1 t1) e1)
sub
(ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ)
@@ -763,6 +824,7 @@ drev des = \case
(EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
(weakenExpr (WCopy (wSinks' @[_,_])) e2)))
+{-
ECase _ e (a :: Ex _ t) b
| STEither t1 t2 <- typeOf e
, Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des e
@@ -962,7 +1024,6 @@ drev des = \case
case assertSubenvEmpty sub of { Refl ->
let tapety = tapeTy (bindingsBinds e0) in
let collectexpr = bindingsCollect e0 in
- -- let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in
Ret (she0 `BPush` (shty, she1)
`BPush` (STArr ndim tapety
,EBuild ext ndim
@@ -1129,6 +1190,7 @@ drev des = \case
EAccum{} -> err_accum
EZero{} -> err_monoid
EPlus{} -> err_monoid
+-}
where
err_accum = error "Accumulator operations unsupported in the source program"