summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-10-27 22:02:18 +0100
committerTom Smeding <tom@tomsmeding.com>2024-10-27 22:02:18 +0100
commit4acecc1099caefc2bd2fbe252d30d52ead2c74be (patch)
treec32d4ae3fd7269df5543c0b8fe3cbbce87b1ce28 /src
parent8d15d92bb9a0472e096ff714e45b10cf16134a30 (diff)
WIP preserve only subset of D0 bindings in dual (...)
The point of this is to ensure that when an expression occurs in a Build, then the parts of D0 that are only there to make sharing work out for D1 are not laboriously taped in an array and preserved for D2, only for D2 to ignore them. However, while the subtape machinery is a good first step, this is not everything: the current Build translation makes a Build for the (elementwise) tape and separately a build for the primal. Because the primal _does_ generally need the subtaped-away stuff, we can't just not tape those. TODO: figure out how to resolve this / what the next step is.
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs6
-rw-r--r--src/AST/Count.hs12
-rw-r--r--src/AST/Env.hs20
-rw-r--r--src/CHAD.hs518
4 files changed, 314 insertions, 242 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 6370148..9f1da7a 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -20,7 +20,6 @@ import Data.Functor.Const
import Data.Kind (Type)
import Array
-import AST.Env
import AST.Types
import AST.Weaken
import CHAD.Types
@@ -289,11 +288,6 @@ subst' f w = \case
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
-wUndoSubenv :: Subenv env env' -> env' :> env
-wUndoSubenv SETop = WId
-wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub)
-wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub
-
slistIdx :: SList f list -> Idx list t -> f t
slistIdx (SCons x _) IZ = x
slistIdx (SCons _ list) (IS i) = slistIdx list i
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 31720a5..a928743 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -150,12 +150,12 @@ deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =
unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
unsafeWeakenWithSubenv = \sub ->
- subst (\x t i -> case sinkWithSubenv i sub of
+ subst (\x t i -> case sinkViaSubenv i sub of
Just i' -> EVar x t i'
Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
where
- sinkWithSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
- sinkWithSubenv IZ (SEYes _) = Just IZ
- sinkWithSubenv IZ (SENo _) = Nothing
- sinkWithSubenv (IS i) (SEYes sub) = IS <$> sinkWithSubenv i sub
- sinkWithSubenv (IS i) (SENo sub) = sinkWithSubenv i sub
+ sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
+ sinkViaSubenv IZ (SEYes _) = Just IZ
+ sinkViaSubenv IZ (SENo _) = Nothing
+ sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub
+ sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub
diff --git a/src/AST/Env.hs b/src/AST/Env.hs
index c33bad3..4f34166 100644
--- a/src/AST/Env.hs
+++ b/src/AST/Env.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -14,8 +15,8 @@ import Data
-- @env'@ ('SEYes') or not included in @env'@ ('SENo').
data Subenv env env' where
SETop :: Subenv '[] '[]
- SEYes :: Subenv env env' -> Subenv (t : env) (t : env')
- SENo :: Subenv env env' -> Subenv (t : env) env'
+ SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env')
+ SENo :: forall t env env'. Subenv env env' -> Subenv (t : env) env'
deriving instance Show (Subenv env env')
subList :: SList f env -> Subenv env env' -> SList f env'
@@ -41,3 +42,18 @@ subenvCompose SETop SETop = SETop
subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2)
subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2)
+
+subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1')
+subenvConcat sub1 SETop = sub1
+subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2)
+subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2)
+
+sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0
+sinkWithSubenv SETop = WId
+sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub
+sinkWithSubenv (SENo sub) = sinkWithSubenv sub
+
+wUndoSubenv :: Subenv env env' -> env' :> env
+wUndoSubenv SETop = WId
+wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub)
+wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 8319080..1fd34d8 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -20,9 +20,12 @@
-- useful here.
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS -Wno-partial-type-signatures #-}
+
+-- TODO DO NOT COMMIT THIS
+{-# OPTIONS -Wno-unused-top-binds #-}
module CHAD (
- drev,
- freezeRet,
+ -- drev,
+ -- freezeRet,
Storage(..),
Descr(..),
Select,
@@ -44,6 +47,8 @@ import Data
import Lemmas
+------------------------------ TAPES AND BINDINGS ------------------------------
+
type family Tape binds where
Tape '[] = TNil
Tape (t : ts) = TPair t (Tape ts)
@@ -186,6 +191,10 @@ reconstructBindings binds tape =
(bconcat (mapBindings fromUnfExpr unf) build)
,sreverse (stapeUnfoldings binds))
+
+--------------------------------- VECTORISATION --------------------------------
+-- Currently only used in D[build1], should be removed.
+
type family Vectorise n list where
Vectorise _ '[] = '[]
Vectorise n (t : ts) = TArr n t : Vectorise n ts
@@ -231,12 +240,131 @@ vectorise1Binds env n (bs `BPush` (t, e)) =
(vectoriseExpr SNil (bindingsBinds bs) env e)
in bs' `BPush` (STArr (SS SZ) t, e')
+
+---------------------- ENVIRONMENT DESCRIPTION AND STORAGE ---------------------
+
+type Storage :: Symbol -> Type
+data Storage s where
+ SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator
+ SMerge :: Storage "merge" -- ^ just return and merge
+deriving instance Show (Storage s)
+
+-- | Environment description
+data Descr env sto where
+ DTop :: Descr '[] '[]
+ DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
+deriving instance Show (Descr env sto)
+
+descrList :: Descr env sto -> SList STy env
+descrList DTop = SNil
+descrList (des `DPush` (t, _)) = t `SCons` descrList des
+
+-- | This could have more precise typing on the output storage.
+subDescr :: Descr env sto -> Subenv env env'
+ -> (forall sto'. Descr env' sto'
+ -> Subenv (Select env sto "merge") (Select env' sto' "merge")
+ -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
+ -> Subenv (D1E env) (D1E env')
+ -> r)
+ -> r
+subDescr DTop SETop k = k DTop SETop SETop SETop
+subDescr (des `DPush` (t, sto)) (SEYes sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
+ SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
+subDescr (des `DPush` (_, sto)) (SENo sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
+ SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
+
-- | Select only the types from the environment that have the specified storage
type family Select env sto s where
Select '[] '[] _ = '[]
Select (t : ts) (s : sto) s = t : Select ts sto s
Select (_ : ts) (_ : sto) s = Select ts sto s
+select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
+select _ DTop = SNil
+select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
+select s@SMerge (DPush des (_, SAccum)) = select s des
+select s@SAccum (DPush des (_, SMerge)) = select s des
+select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
+
+
+---------------------------------- DERIVATIVES ---------------------------------
+
+d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
+d1op (OAdd t) e = EOp ext (OAdd t) e
+d1op (OMul t) e = EOp ext (OMul t) e
+d1op (ONeg t) e = EOp ext (ONeg t) e
+d1op (OLt t) e = EOp ext (OLt t) e
+d1op (OLe t) e = EOp ext (OLe t) e
+d1op (OEq t) e = EOp ext (OEq t) e
+d1op ONot e = EOp ext ONot e
+d1op OAnd e = EOp ext OAnd e
+d1op OOr e = EOp ext OOr e
+d1op OIf e = EOp ext OIf e
+d1op ORound64 e = EOp ext ORound64 e
+d1op OToFl64 e = EOp ext OToFl64 e
+
+-- | Both primal and dual must be duplicable expressions
+data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
+ | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))
+
+d2op :: SOp a t -> D2Op a t
+d2op op = case op of
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d)
+ OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
+ EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
+ ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
+ OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ ONot -> Linear $ \_ -> ENil ext
+ OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ OIf -> Linear $ \_ -> ENil ext
+ ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
+ OToFl64 -> Linear $ \_ -> ENil ext
+ where
+ d2opUnArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => D2Op (TScal a) t)
+ -> D2Op (TScal a) t
+ d2opUnArrangeInt ty float = case ty of
+ STI32 -> Linear $ \_ -> ENil ext
+ STI64 -> Linear $ \_ -> ENil ext
+ STF32 -> float
+ STF64 -> float
+ STBool -> Linear $ \_ -> ENil ext
+
+ d2opBinArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
+ -> D2Op (TPair (TScal a) (TScal a)) t
+ d2opBinArrangeInt ty float = case ty of
+ STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ STF32 -> float
+ STF64 -> float
+ STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+
+sD1eEnv :: Descr env sto -> SList STy (D1E env)
+sD1eEnv DTop = SNil
+sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
+
+d2ace :: SList STy env -> SList STy (D2AcE env)
+d2ace SNil = SNil
+d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts)
+
+-- d1W :: env :> env' -> D1E env :> D1E env'
+-- d1W WId = WId
+-- d1W WSink = WSink
+-- d1W (WCopy w) = WCopy (d1W w)
+-- d1W (WPop w) = WPop (d1W w)
+-- d1W (WThen u w) = WThen (d1W u) (d1W w)
+
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)
@@ -250,6 +378,16 @@ conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i)
conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i)
conv2Idx DTop i = case i of {}
+
+------------------------------------ LEMMAS ------------------------------------
+
+indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
+indexTupD1Id SZ = Refl
+indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
+
+------------------------------------ MONOIDS -----------------------------------
+
zero :: STy t -> Ex env (D2 t)
zero = EZero
-- TODO: this original definition needs to be used as the post-processing after
@@ -312,9 +450,77 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext
zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
-indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
-indexTupD1Id SZ = Refl
-indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
+------------------------------------ SUBENVS -----------------------------------
+
+subenvPlus :: SList STy env
+ -> Subenv env env1 -> Subenv env env2
+ -> (forall env3. Subenv env env3
+ -> Subenv env3 env1
+ -> Subenv env3 env2
+ -> (Ex exenv (Tup (D2E env1))
+ -> Ex exenv (Tup (D2E env2))
+ -> Ex exenv (Tup (D2E env3)))
+ -> r)
+ -> r
+subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
+subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
+ subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SENo sub3) s31 s32 pl
+subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k =
+ subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ))
+subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k =
+ subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 ->
+ ELet ext e2 $
+ EPair ext (pl (weakenExpr WSink e1)
+ (EFst ext (EVar ext (typeOf e2) IZ)))
+ (ESnd ext (EVar ext (typeOf e2) IZ))
+subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
+ subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 ->
+ ELet ext e1 $
+ ELet ext (weakenExpr WSink e2) $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
+ (EFst ext (EVar ext (typeOf e2) IZ)))
+ (plus t
+ (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
+ (ESnd ext (EVar ext (typeOf e2) IZ)))
+
+expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0))
+expandSubenvZeros _ SETop _ = ENil ext
+expandSubenvZeros (SCons t ts) (SEYes sub) e =
+ ELet ext e $
+ let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ
+ in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
+expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t)
+
+assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
+assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
+assertSubenvEmpty SETop = Refl
+assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
+
+popFromScope
+ :: Descr env0 sto
+ -> STy a
+ -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub
+ -> Ex env (Tup (D2E envSub))
+ -> (forall envSub'.
+ Subenv (Select env0 sto "merge") envSub'
+ -> Ex env (TPair (Tup (D2E envSub')) (D2 a))
+ -> r)
+ -> r
+popFromScope _ ty sub e k = case sub of
+ SEYes sub' -> k sub' e
+ SENo sub' -> k sub' $ EPair ext e (zero ty)
+
+
+--------------------------------- ACCUMULATORS ---------------------------------
accumPromote :: forall dt env sto proxy r.
proxy dt
@@ -411,252 +617,98 @@ uninvertTup (t `SCons` list) tcore e =
(ESnd ext (EVar ext recT IZ))
(ESnd ext (EFst ext (EVar ext recT IZ))))
+
+---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
+
data Ret env0 sto t =
- forall shbinds env0Merge.
+ forall shbinds tapebinds env0Merge.
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
+ (Subenv shbinds tapebinds)
(Ex (Append shbinds (D1E env0)) (D1 t))
(Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
+ (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
deriving instance Show (Ret env0 sto t)
-data RetPair env0 sto env shbinds t =
+data RetPair env0 sto env shbinds tapebinds t =
forall env0Merge.
RetPair (Ex (Append shbinds env) (D1 t))
(Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
-deriving instance Show (RetPair env0 sto env shbinds t)
+ (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
+deriving instance Show (RetPair env0 sto env shbinds tapebinds t)
data Rets env0 sto env list =
- forall shbinds.
+ forall shbinds tapebinds.
Rets (Bindings Ex env shbinds)
- (SList (RetPair env0 sto env shbinds) list)
+ (Subenv shbinds tapebinds)
+ (SList (RetPair env0 sto env shbinds tapebinds) list)
deriving instance Show (Rets env0 sto env list)
-subenvPlus :: SList STy env
- -> Subenv env env1 -> Subenv env env2
- -> (forall env3. Subenv env env3
- -> Subenv env3 env1
- -> Subenv env3 env2
- -> (Ex exenv (Tup (D2E env1))
- -> Ex exenv (Tup (D2E env2))
- -> Ex exenv (Tup (D2E env3)))
- -> r)
- -> r
-subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
-subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SENo sub3) s31 s32 pl
-subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 ->
- ELet ext e1 $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
- (weakenExpr WSink e2))
- (ESnd ext (EVar ext (typeOf e1) IZ))
-subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 ->
- ELet ext e2 $
- EPair ext (pl (weakenExpr WSink e1)
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ))
-subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 ->
- ELet ext e1 $
- ELet ext (weakenExpr WSink e2) $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (plus t
- (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ)))
-
-expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0))
-expandSubenvZeros _ SETop _ = ENil ext
-expandSubenvZeros (SCons t ts) (SEYes sub) e =
- ELet ext e $
- let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ
- in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
-expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t)
-
-assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
-assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
-assertSubenvEmpty SETop = Refl
-assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
-
-popFromScope
- :: Descr env0 sto
- -> STy a
- -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub
- -> Ex env (Tup (D2E envSub))
- -> (forall envSub'.
- Subenv (Select env0 sto "merge") envSub'
- -> Ex env (TPair (Tup (D2E envSub')) (D2 a))
- -> r)
- -> r
-popFromScope _ ty sub e k = case sub of
- SEYes sub' -> k sub' e
- SENo sub' -> k sub' $ EPair ext e (zero ty)
-
--- d1W :: env :> env' -> D1E env :> D1E env'
--- d1W WId = WId
--- d1W WSink = WSink
--- d1W (WCopy w) = WCopy (d1W w)
--- d1W (WPop w) = WPop (d1W w)
--- d1W (WThen u w) = WThen (d1W u) (d1W w)
-
-weakenRetPair :: SList STy shbinds -> env :> env' -> RetPair env0 sto env shbinds t -> RetPair env0 sto env' shbinds t
+weakenRetPair :: SList STy shbinds -> env :> env'
+ -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t
weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2
weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
-weakenRets w (Rets binds list) =
+weakenRets w (Rets binds tapesub list) =
let (binds', _) = weakenBindings weakenExpr w binds
- in Rets binds' (slistMap (weakenRetPair (bindingsBinds binds) w) list)
-
-rebaseRetPair :: forall env b1 b2 env0 sto t f.
- Descr env0 sto -> SList f b1 -> SList f b2
- -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t
-rebaseRetPair descr b1 b2 (RetPair p sub d)
+ in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list)
+
+rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f.
+ Descr env0 sto
+ -> SList f b1 -> SList f b2
+ -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2
+ -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t
+ -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t
+rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d)
| Refl <- lemAppendAssoc @b2 @b1 @env =
RetPair p sub (weakenExpr (autoWeak
- (#d (auto1 @(D2 t)) &. #b2 b2 &. #b1 b1 &. #tl (d2ace (select SAccum descr)))
- (#d :++: (#b2 :++: #tl))
- (#d :++: ((#b2 :++: #b1) :++: #tl)))
+ (#d (auto1 @(D2 t))
+ &. #t2 (subList b2 subtape2)
+ &. #t1 (subList b1 subtape1)
+ &. #tl (d2ace (select SAccum descr)))
+ (#d :++: (#t2 :++: #tl))
+ (#d :++: ((#t2 :++: #t1) :++: #tl)))
d)
retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
-retConcat _ SNil = Rets BTop SNil
-retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list)
- | Rets binds1 pairs1 <- retConcat descr list
- , Rets (binds :: Bindings _ _ shbinds2) pairs <- weakenRets (sinkWithBindings b) (Rets binds1 pairs1)
- , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D1E env0)
- , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D2AcE (Select env0 sto "accum"))
+retConcat _ SNil = Rets BTop SETop SNil
+retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list)
+ | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs
+ <- weakenRets (sinkWithBindings b) (retConcat descr list)
+ , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0)
+ , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum"))
= Rets (bconcat b binds)
+ (subenvConcat subtape subtape2)
(SCons (RetPair (weakenExpr (sinkWithBindings binds) p)
sub
- (weakenExpr (WCopy (sinkWithBindings binds)) d))
- (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)) pairs))
-
-d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
-d1op (OAdd t) e = EOp ext (OAdd t) e
-d1op (OMul t) e = EOp ext (OMul t) e
-d1op (ONeg t) e = EOp ext (ONeg t) e
-d1op (OLt t) e = EOp ext (OLt t) e
-d1op (OLe t) e = EOp ext (OLe t) e
-d1op (OEq t) e = EOp ext (OEq t) e
-d1op ONot e = EOp ext ONot e
-d1op OAnd e = EOp ext OAnd e
-d1op OOr e = EOp ext OOr e
-d1op OIf e = EOp ext OIf e
-d1op ORound64 e = EOp ext ORound64 e
-d1op OToFl64 e = EOp ext OToFl64 e
-
--- | Both primal and dual must be duplicable expressions
-data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
- | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))
-
-d2op :: SOp a t -> D2Op a t
-d2op op = case op of
- OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d)
- OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
- EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
- (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
- ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
- OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- ONot -> Linear $ \_ -> ENil ext
- OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- OIf -> Linear $ \_ -> ENil ext
- ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
- OToFl64 -> Linear $ \_ -> ENil ext
- where
- d2opUnArrangeInt :: SScalTy a
- -> (D2s a ~ TScal a => D2Op (TScal a) t)
- -> D2Op (TScal a) t
- d2opUnArrangeInt ty float = case ty of
- STI32 -> Linear $ \_ -> ENil ext
- STI64 -> Linear $ \_ -> ENil ext
- STF32 -> float
- STF64 -> float
- STBool -> Linear $ \_ -> ENil ext
-
- d2opBinArrangeInt :: SScalTy a
- -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
- -> D2Op (TPair (TScal a) (TScal a)) t
- d2opBinArrangeInt ty float = case ty of
- STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- STF32 -> float
- STF64 -> float
- STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
-
-type Storage :: Symbol -> Type
-data Storage s where
- SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator
- SMerge :: Storage "merge" -- ^ just return and merge
-deriving instance Show (Storage s)
-
--- | Environment description
-data Descr env sto where
- DTop :: Descr '[] '[]
- DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
-deriving instance Show (Descr env sto)
-
-descrList :: Descr env sto -> SList STy env
-descrList DTop = SNil
-descrList (des `DPush` (t, _)) = t `SCons` descrList des
-
-select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
-select _ DTop = SNil
-select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
-select s@SMerge (DPush des (_, SAccum)) = select s des
-select s@SAccum (DPush des (_, SMerge)) = select s des
-select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
-
--- | This could have more precise typing on the output storage.
-subDescr :: Descr env sto -> Subenv env env'
- -> (forall sto'. Descr env' sto'
- -> Subenv (Select env sto "merge") (Select env' sto' "merge")
- -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
- -> Subenv (D1E env) (D1E env')
- -> r)
- -> r
-subDescr DTop SETop k = k DTop SETop SETop SETop
-subDescr (des `DPush` (t, sto)) (SEYes sub) k =
- subDescr des sub $ \des' submerge subaccum subd1e ->
- case sto of
- SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
- SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
-subDescr (des `DPush` (_, sto)) (SENo sub) k =
- subDescr des sub $ \des' submerge subaccum subd1e ->
- case sto of
- SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
- SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
-
-sD1eEnv :: Descr env sto -> SList STy (D1E env)
-sD1eEnv DTop = SNil
-sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
-
-d2ace :: SList STy env -> SList STy (D2AcE env)
-d2ace SNil = SNil
-d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts)
+ (weakenExpr (WCopy (sinkWithSubenv subtape2)) d))
+ (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)
+ subtape subtape2)
+ pairs))
freezeRet :: Descr env sto
-> Ret env sto t
-> Ex (D1E env) (D2 t) -- the incoming cotangent value
-> Ex (Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
-freezeRet descr (Ret e0 e1 sub e2) d =
+freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) d =
let (e0', wInsertD2Ac) = weakenBindings weakenExpr (wSinks (d2ace (select SAccum descr))) e0
- e2' = weakenExpr (WCopy (wCopies (bindingsBinds e0) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2
+ e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2
in letBinds e0' $
EPair ext
(weakenExpr wInsertD2Ac e1)
(ELet ext (weakenExpr (sinkWithBindings e0 .> wSinks (d2ace (select SAccum descr))) d) $
- ELet ext e2' $
+ ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #shbinds (bindingsBinds e0)
+ &. #d2ace (d2ace (select SAccum descr))
+ &. #tl (sD1eEnv descr))
+ (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
+ (#d :++: #shbinds :++: #d2ace :++: #tl))
+ e2') $
expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
+
+---------------------------- THE CHAD TRANSFORMATION ---------------------------
+
drev :: forall env sto t.
Descr env sto
-> Ex env t -> Ret env sto t
@@ -665,50 +717,55 @@ drev des = \case
case conv2Idx des i of
Left accI ->
Ret BTop
+ SETop
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (select SMerge des))
(EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI)))
Right tupI ->
Ret BTop
+ SETop
(EVar ext (d1 t) (conv1Idx i))
(subenvOnehot (select SMerge des) tupI)
(EPair ext (ENil ext) (EVar ext (d2 t) IZ))
ELet _ rhs body
- | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs
- , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, SMerge)) body
+ | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2
+ <- drev des rhs
+ , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2
+ <- drev (des `DPush` (typeOf rhs, SMerge)) body
, let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
, Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
- , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D2AcE (Select env sto "accum"))
- , Refl <- lemAppendNil @body_shbinds ->
+ , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
popFromScope des (typeOf rhs) subBody body2 $ \subBody' body2' ->
subenvPlus (select SMerge des) subRHS subBody' $ \subBoth _ _ plus_RHS_Body ->
let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
+ (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
(weakenExpr wbody0' body1)
subBoth
(ELet ext
(weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (bindingsBinds body0)
- &. #rhs (SCons (typeOf rhs1) (bindingsBinds rhs0))
+ &. #body (subList (bindingsBinds body0) subtapeBody)
+ &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
&. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #tl)
- (#d :++: #body :++: #rhs :++: #tl))
+ (#d :++: (#body :++: #rhs) :++: #tl))
body2') $
ELet ext
(ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
- weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2) $
+ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
plus_RHS_Body
(EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
(EFst ext (EVar ext bodyResType (IS IZ))))
EPair _ a b
- | Rets binds (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
+ | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
<- retConcat des $ drev des a `SCons` drev des b `SCons` SNil
, let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->
Ret binds
+ subtape
(EPair ext a1 b1)
subBoth
(ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)
@@ -722,28 +779,31 @@ drev des = \case
(EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)))
EFst _ e
- | Ret e0 e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des e
, STPair t1 t2 <- typeOf e ->
Ret e0
+ subtape
(EFst ext e1)
sub
(ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des e
, STPair t1 t2 <- typeOf e ->
Ret e0
+ subtape
(ESnd ext e1)
sub
(ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
weakenExpr (WCopy WSink) e2)
- ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
+ ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
EInl _ t2 e
- | Ret e0 e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des e ->
Ret e0
+ subtape
(EInl ext (d1 t2) e1)
sub
(ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ)
@@ -753,8 +813,9 @@ drev des = \case
(EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")))
EInr _ t1 e
- | Ret e0 e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des e ->
Ret e0
+ subtape
(EInr ext (d1 t1) e1)
sub
(ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ)
@@ -763,6 +824,7 @@ drev des = \case
(EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
(weakenExpr (WCopy (wSinks' @[_,_])) e2)))
+{-
ECase _ e (a :: Ex _ t) b
| STEither t1 t2 <- typeOf e
, Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des e
@@ -962,7 +1024,6 @@ drev des = \case
case assertSubenvEmpty sub of { Refl ->
let tapety = tapeTy (bindingsBinds e0) in
let collectexpr = bindingsCollect e0 in
- -- let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in
Ret (she0 `BPush` (shty, she1)
`BPush` (STArr ndim tapety
,EBuild ext ndim
@@ -1129,6 +1190,7 @@ drev des = \case
EAccum{} -> err_accum
EZero{} -> err_monoid
EPlus{} -> err_monoid
+-}
where
err_accum = error "Accumulator operations unsupported in the source program"