diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 143 |
1 files changed, 132 insertions, 11 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 943f0a2..7747d46 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -262,7 +262,7 @@ vectoriseExpr :: forall prefix binds env t f. -> Ex (Append prefix (Append binds env)) t -> Ex (TIx : Append prefix (Append (Vectorise (S Z) binds) env)) t vectoriseExpr prefix binds env = - let wTarget :: Layout ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e + let wTarget :: Layout True ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e -> e :> TIx : Append prefix (Append (Vectorise (S Z) binds) env) wTarget layout = autoWeak (#ix (auto1 @TIx) &. #pre prefix &. #vbinds (vectoriseEnv (SS SZ) binds) &. #env env) @@ -422,6 +422,10 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) +indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) +indexTupD1Id SZ = Refl +indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + accumPromote :: forall dt env sto proxy r. proxy dt -> Descr env sto @@ -974,6 +978,7 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) + -- TODO: either remove EBuilds1 entirely or rewrite it to work with an array of tapes instead of a vectorised tape EBuild1 _ ne (orige :: Ex _ eltty) | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne -- allowed to ignore ne2 here because ne has a discrete result , let eltty = typeOf orige -> @@ -1050,6 +1055,90 @@ drev des = \case ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ))) }} + EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty) + | Ret (she0 :: Bindings _ _ she_binds) she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result + , let eltty = typeOf orige + , shty :: STy shty <- tTup (sreplicate ndim tIx) + , Refl <- indexTupD1Id ndim -> + deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> + let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in + subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> + accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> + case drev (prodes `DPush` (shty, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 -> + case assertSubenvEmpty sub of { Refl -> + let tapety = tapeTy (bindingsBinds e0) in + let collectexpr = bindingsCollect e0 in + -- let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in + Ret (she0 `BPush` (shty, she1) + `BPush` (STArr ndim tapety + ,EBuild ext ndim + (EVar ext shty IZ) + (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #she0 (bindingsBinds she0) + &. #d1env (sD1eEnv des) + &. #d1env' (sD1eEnv usedDes)) + (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#ix :++: #sh :++: #she0 :++: #d1env)) + e0)) $ + collectexpr (autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #she0 (bindingsBinds she0) + &. #e0 (bindingsBinds e0) + &. #d1env (sD1eEnv des) + &. #d1env' (sD1eEnv usedDes)) + (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#e0 :++: #ix :++: #sh :++: #she0 :++: #d1env))))) + (EBuild ext ndim + (EVar ext shty (IS IZ)) + (ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (IS IZ)) + (EVar ext shty IZ)) $ + let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ + in letBinds rebinds $ + weakenExpr (autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #she0 (bindingsBinds she0) + &. #e0 (bindingsBinds e0) + &. #tape (tapety `SCons` SNil) + &. #tapearr (STArr ndim tapety `SCons` SNil) + &. #prerebinds prerebinds + &. #d1env (sD1eEnv des) + &. #d1env' (sD1eEnv usedDes)) + (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + ((#e0 :++: #prerebinds) :++: #tape :++: #ix :++: #tapearr :++: #sh :++: #she0 :++: #d1env)) + e1)) + (subenvCompose subMergeUsed proSub) + (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_binds) : Tup (Replicate ndim TIx) : Append she_binds (D2AcE (Select env sto "accum"))) (d2ace envPro) in + ESnd ext $ + uninvertTup (d2e envPro) (STArr ndim STNil) $ + makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ + EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext ndim (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ + in letBinds rebinds $ + weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) + &. #pro (d2ace envPro) + &. #ebinds (bindingsBinds e0) + &. #prerebinds prerebinds + &. #tape (tapety `SCons` SNil) + &. #ix (shty `SCons` SNil) + &. #darr (STArr ndim (d2 eltty) `SCons` SNil) + &. #tapearr (STArr ndim tapety `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #shebinds (bindingsBinds she0) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des))) + (#pro :++: #d :++: #ebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#ebinds :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #shebinds :++: #d2acEnv) + .> wPro (bindingsBinds e0)) + e2) + }} + EUnit _ e | Ret e0 e1 sub e2 <- drev des e -> Ret e0 @@ -1072,23 +1161,55 @@ drev des = \case | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil , STArr (SS n) eltty <- typeOf e -> - Ret (binds `BPush` (tTup (sreplicate (SS n) tIx), EShape ext e1)) - (weakenExpr WSink (EIdx1 ext e1 ei1)) + Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)) + (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) IZ) + (weakenExpr WSink ei1)) sub - (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (ELet ext (ebuildUp1 n (EFst ext (EShape ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)))) + (ESnd ext (EShape ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)))) (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ weakenExpr (WCopy (WSink .> WSink)) e2) + EIdx _ n e ei + -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. + | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) + <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil + , STArr _ eltty <- typeOf e + , Refl <- indexTupD1Id n -> + Ret (binds `BPush` (STArr n (d1 eltty), e1)) + (EIdx ext n (EVar ext (STArr n (d1 eltty)) IZ) + (weakenExpr WSink ei1)) + sub + (ELet ext (EBuild ext n (EShape ext (EVar ext (STArr n (d1 eltty)) (IS IZ))) + (EVar ext (d2 eltty) (IS IZ))) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + + EShape _ e + -- Allowed to ignore e2 here because the output of EShape is discrete, + -- hence we'd be passing a zero cotangent to e2 anyway. + | Ret e0 e1 _ _ <- drev des e + , STArr n _ <- typeOf e + , Refl <- indexTupD1Id n -> + Ret e0 + (EShape ext e1) + (subenvNone (select SMerge des)) + (ENil ext) + + ESum1Inner _ e + | Ret e0 e1 sub e2 <- drev des e + , STArr (SS n) t <- typeOf e -> + Ret (e0 `BPush` (STArr (SS n) t, e1)) + (ESum1Inner ext (EVar ext (STArr (SS n) t) IZ)) + sub + (ELet ext (EReplicate1Inner ext + (ESnd ext (EShape ext (EVar ext (STArr (SS n) t) (IS IZ)))) + (EVar ext (STArr n (d2 t)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + -- These should be the next to be implemented, I think - ESum1Inner{} -> err_unsupported "ESum" - EReplicate1Inner{} -> err_unsupported "EReplicate" - EShape{} -> err_unsupported "EShape" + EReplicate1Inner{} -> err_unsupported "EReplicate1Inner" EFold1Inner{} -> err_unsupported "EFold1Inner" - EIdx{} -> err_unsupported "EIdx" - EBuild{} -> err_unsupported "EBuild" - EWith{} -> err_accum EAccum{} -> err_accum |