diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-10-23 23:53:37 +0200 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-10-23 23:53:37 +0200 |
| commit | 2542f5ef42452967fec1d2376927c1f36bf263f4 (patch) | |
| tree | 717d97be4d21c4ac0355270ac81df33296b8b852 /src/AST.hs | |
| parent | f805440cf8833d238f848dd07f89b8ed5bc69e90 (diff) | |
WIP fold: Implement D[fold1i]
Still need to handle the new primitives in the rest of the library
Diffstat (limited to 'src/AST.hs')
| -rw-r--r-- | src/AST.hs | 14 |
1 files changed, 14 insertions, 0 deletions
@@ -62,6 +62,7 @@ data Expr x env t where -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) + -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) @@ -69,6 +70,19 @@ data Expr x env t where EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) t1)) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1) + -> Expr x env (TPair (TArr n t1) -- normal primal fold output + (TArr (S n) (TPair t1 tape))) -- bag-of-goodies: zip (prescanl) (the tape stores) + -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage. + EFold1InnerD2 :: x (TArr (S n) t2) -> Commutative + -> SMTy t2 -- t2 must be a monoid in order to be able to add all inner-vector contributions to the single x0 + -- TODO: `fold1i (*)` should have zero tape stores since primals are directly made available here, but this is not yet true + -> Expr x (t2 : t1 : t1 : tape : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) + -> Expr x env (TArr (S n) t1) -- primal input array + -> Expr x env (TArr (S n) (TPair t1 tape)) -- bag-of-goodies from EFold1InnerD1 + -> Expr x env (TArr n t2) -- incoming cotangent + -> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array + -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t |
