aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs90
1 files changed, 87 insertions, 3 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index cfae98d..ec719e8 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1116,6 +1116,89 @@ drev des accumMap sd = \case
e2)
}}
+ EFold1Inner _ commut origef ex₀ earr
+ | SpArr @_ @sdElt sdElt <- sd
+ , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr
+ , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil)
+ <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil ->
+ deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') ->
+ let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in
+ subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
+ let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
+ let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
+ let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
+ let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in
+ case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 ->
+ let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
+ let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
+ primalTy = STPair (STArr ndim (d1 eltty)) bogTy
+ library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil)
+ &. #parr (auto1 @(TArr (S n) (D1 elt)))
+ &. #px₀ (auto1 @(D1 elt))
+ &. #primal (primalTy `SCons` SNil)
+ &. #darr (auto1 @(TArr n sdElt))
+ &. #d (auto1 @(D2 elt))
+ &. #x₀abinds (bindingsBinds bindsx₀a)
+ &. #fbinds (bindingsBinds ef0)
+ &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a)
+ &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
+ &. #ftape (auto1 @(Tape e_tape))
+ &. #efPrerebinds efPrerebinds
+ &. #propr (d1e envPro)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #d2acPro (d2ace envPro)
+ &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro))))
+ wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f ->
+ Ret (bconcat bindsx₀a mergePrimalBindings'
+ `bpush` weakenExpr wOverPrimalBindings ex₀1
+ `bpush` weakenExpr (WSink .> wOverPrimalBindings) ea1
+ `bpush` EFold1InnerD1 ext commut
+ (letBinds (fst (weakenBindingsE (autoWeak library
+ (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))
+ ef0)) $
+ EPair ext
+ (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))
+ ef1)
+ (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))))
+ (EVar ext (d1 eltty) (IS IZ))
+ (EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
+ (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))
+ (EFst ext (EVar ext primalTy IZ))
+ subx₀af
+ (let layout1 = #darr :++: #primal :++: #parr :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
+ elet
+ (uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $
+ makeAccumulators (autoWeak library #propr layout1) envPro $
+ let layout2 = #d2acPro :++: layout1 in
+ EFold1InnerD2 ext commut (d2M eltty)
+ (letBinds (efRebinds (IS (IS (IS IZ)))) $
+ let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #d :++: #xy :++: #ftape :++: layout2 in
+ elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $
+ weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3
+ .> wPro (subList (bindingsBinds ef0) subtapeEf))
+ ef2) $
+ EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
+ (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
+ (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))
+ (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
+ (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
+ (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $
+ plus_x₀a_f
+ (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
+ plus_x₀_a
+ (subst0 (EFst ext (EFst ext (evar IZ))) ex₀2)
+ (subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
+ (ESnd ext (evar IZ)))
+ }
+
EUnit _ e
| SpArr sdElt <- sd
, Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e ->
@@ -1229,9 +1312,6 @@ drev des accumMap sd = \case
EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
- -- These should be the next to be implemented, I think
- EFold1Inner{} -> err_unsupported "EFold1Inner"
-
ENothing{} -> err_unsupported "ENothing"
EJust{} -> err_unsupported "EJust"
EMaybe{} -> err_unsupported "EMaybe"
@@ -1246,10 +1326,14 @@ drev des accumMap sd = \case
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
+ EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
+ EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
+
where
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
+ err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program"
contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))