summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs143
1 files changed, 132 insertions, 11 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 943f0a2..7747d46 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -262,7 +262,7 @@ vectoriseExpr :: forall prefix binds env t f.
-> Ex (Append prefix (Append binds env)) t
-> Ex (TIx : Append prefix (Append (Vectorise (S Z) binds) env)) t
vectoriseExpr prefix binds env =
- let wTarget :: Layout ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e
+ let wTarget :: Layout True ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e
-> e :> TIx : Append prefix (Append (Vectorise (S Z) binds) env)
wTarget layout =
autoWeak (#ix (auto1 @TIx) &. #pre prefix &. #vbinds (vectoriseEnv (SS SZ) binds) &. #env env)
@@ -422,6 +422,10 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext
zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
+indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
+indexTupD1Id SZ = Refl
+indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
accumPromote :: forall dt env sto proxy r.
proxy dt
-> Descr env sto
@@ -974,6 +978,7 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
+ -- TODO: either remove EBuilds1 entirely or rewrite it to work with an array of tapes instead of a vectorised tape
EBuild1 _ ne (orige :: Ex _ eltty)
| Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne -- allowed to ignore ne2 here because ne has a discrete result
, let eltty = typeOf orige ->
@@ -1050,6 +1055,90 @@ drev des = \case
ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ)))
}}
+ EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty)
+ | Ret (she0 :: Bindings _ _ she_binds) she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result
+ , let eltty = typeOf orige
+ , shty :: STy shty <- tTup (sreplicate ndim tIx)
+ , Refl <- indexTupD1Id ndim ->
+ deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
+ let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
+ subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro ->
+ case drev (prodes `DPush` (shty, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 ->
+ case assertSubenvEmpty sub of { Refl ->
+ let tapety = tapeTy (bindingsBinds e0) in
+ let collectexpr = bindingsCollect e0 in
+ -- let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in
+ Ret (she0 `BPush` (shty, she1)
+ `BPush` (STArr ndim tapety
+ ,EBuild ext ndim
+ (EVar ext shty IZ)
+ (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #she0 (bindingsBinds she0)
+ &. #d1env (sD1eEnv des)
+ &. #d1env' (sD1eEnv usedDes))
+ (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#ix :++: #sh :++: #she0 :++: #d1env))
+ e0)) $
+ collectexpr (autoWeak (#ix (shty `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #she0 (bindingsBinds she0)
+ &. #e0 (bindingsBinds e0)
+ &. #d1env (sD1eEnv des)
+ &. #d1env' (sD1eEnv usedDes))
+ (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#e0 :++: #ix :++: #sh :++: #she0 :++: #d1env)))))
+ (EBuild ext ndim
+ (EVar ext shty (IS IZ))
+ (ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (IS IZ))
+ (EVar ext shty IZ)) $
+ let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ
+ in letBinds rebinds $
+ weakenExpr (autoWeak (#ix (shty `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #she0 (bindingsBinds she0)
+ &. #e0 (bindingsBinds e0)
+ &. #tape (tapety `SCons` SNil)
+ &. #tapearr (STArr ndim tapety `SCons` SNil)
+ &. #prerebinds prerebinds
+ &. #d1env (sD1eEnv des)
+ &. #d1env' (sD1eEnv usedDes))
+ (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ ((#e0 :++: #prerebinds) :++: #tape :++: #ix :++: #tapearr :++: #sh :++: #she0 :++: #d1env))
+ e1))
+ (subenvCompose subMergeUsed proSub)
+ (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_binds) : Tup (Replicate ndim TIx) : Append she_binds (D2AcE (Select env sto "accum"))) (d2ace envPro) in
+ ESnd ext $
+ uninvertTup (d2e envPro) (STArr ndim STNil) $
+ makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
+ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
+ -- the cotangent for this element
+ ELet ext (EIdx ext ndim (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty (IS IZ))) $
+ let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ
+ in letBinds rebinds $
+ weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
+ &. #pro (d2ace envPro)
+ &. #ebinds (bindingsBinds e0)
+ &. #prerebinds prerebinds
+ &. #tape (tapety `SCons` SNil)
+ &. #ix (shty `SCons` SNil)
+ &. #darr (STArr ndim (d2 eltty) `SCons` SNil)
+ &. #tapearr (STArr ndim tapety `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #shebinds (bindingsBinds she0)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des)))
+ (#pro :++: #d :++: #ebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
+ ((#ebinds :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #shebinds :++: #d2acEnv)
+ .> wPro (bindingsBinds e0))
+ e2)
+ }}
+
EUnit _ e
| Ret e0 e1 sub e2 <- drev des e ->
Ret e0
@@ -1072,23 +1161,55 @@ drev des = \case
| Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
<- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
, STArr (SS n) eltty <- typeOf e ->
- Ret (binds `BPush` (tTup (sreplicate (SS n) tIx), EShape ext e1))
- (weakenExpr WSink (EIdx1 ext e1 ei1))
+ Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1))
+ (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)
+ (weakenExpr WSink ei1))
sub
- (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
- (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (ELet ext (ebuildUp1 n (EFst ext (EShape ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))))
+ (ESnd ext (EShape ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))))
(EVar ext (STArr n (d2 eltty)) (IS IZ))) $
weakenExpr (WCopy (WSink .> WSink)) e2)
+ EIdx _ n e ei
+ -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
+ | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
+ <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
+ , STArr _ eltty <- typeOf e
+ , Refl <- indexTupD1Id n ->
+ Ret (binds `BPush` (STArr n (d1 eltty), e1))
+ (EIdx ext n (EVar ext (STArr n (d1 eltty)) IZ)
+ (weakenExpr WSink ei1))
+ sub
+ (ELet ext (EBuild ext n (EShape ext (EVar ext (STArr n (d1 eltty)) (IS IZ)))
+ (EVar ext (d2 eltty) (IS IZ))) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+
+ EShape _ e
+ -- Allowed to ignore e2 here because the output of EShape is discrete,
+ -- hence we'd be passing a zero cotangent to e2 anyway.
+ | Ret e0 e1 _ _ <- drev des e
+ , STArr n _ <- typeOf e
+ , Refl <- indexTupD1Id n ->
+ Ret e0
+ (EShape ext e1)
+ (subenvNone (select SMerge des))
+ (ENil ext)
+
+ ESum1Inner _ e
+ | Ret e0 e1 sub e2 <- drev des e
+ , STArr (SS n) t <- typeOf e ->
+ Ret (e0 `BPush` (STArr (SS n) t, e1))
+ (ESum1Inner ext (EVar ext (STArr (SS n) t) IZ))
+ sub
+ (ELet ext (EReplicate1Inner ext
+ (ESnd ext (EShape ext (EVar ext (STArr (SS n) t) (IS IZ))))
+ (EVar ext (STArr n (d2 t)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+
-- These should be the next to be implemented, I think
- ESum1Inner{} -> err_unsupported "ESum"
- EReplicate1Inner{} -> err_unsupported "EReplicate"
- EShape{} -> err_unsupported "EShape"
+ EReplicate1Inner{} -> err_unsupported "EReplicate1Inner"
EFold1Inner{} -> err_unsupported "EFold1Inner"
- EIdx{} -> err_unsupported "EIdx"
- EBuild{} -> err_unsupported "EBuild"
-
EWith{} -> err_accum
EAccum{} -> err_accum