aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Drev.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-22 22:41:09 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-22 22:41:09 +0100
commit9b7c3eea7e34f5eb0d91f93b803e853028c2cec8 (patch)
tree25b906bb49218d2743631d0c83e23717012e3b9b /src/CHAD/Drev.hs
parentb4f07c673b7c710f5861bb84e67233c63336c53d (diff)
WIP: Think about fusionfusion
Diffstat (limited to 'src/CHAD/Drev.hs')
-rw-r--r--src/CHAD/Drev.hs18
1 files changed, 9 insertions, 9 deletions
diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs
index bfa964b..eba3719 100644
--- a/src/CHAD/Drev.hs
+++ b/src/CHAD/Drev.hs
@@ -726,7 +726,7 @@ drev :: forall env sto sd t.
(?config :: CHADConfig)
=> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
-> Sparse (D2 t) sd
- -> Expr ValId env t -> Ret env sto sd t
+ -> Expr NoExt ValId env t -> Ret env sto sd t
drev des _ sd | isAbsent sd =
\e ->
Ret BTop
@@ -774,7 +774,7 @@ drev des accumMap sd = \case
(subenvNone (d2e (select SMerge des)))
(ENil ext)
- ELet _ (rhs :: Expr _ _ a) body
+ ELet _ (rhs :: Expr _ _ _ a) body
| ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
, RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body
, Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs
@@ -872,7 +872,7 @@ drev des accumMap sd = \case
(EError ext (contribTupTy des sub') "inr<-dinl")
(inj1 $ weakenExpr (WCopy WSink) e2))
- ECase _ e (a :: Expr _ _ t) b
+ ECase _ e (a :: Expr _ _ _ t) b
| STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e
, ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
, ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
@@ -1041,7 +1041,7 @@ drev des accumMap sd = \case
(subenvNone (d2e (select SMerge des)))
(ENil ext)
- EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty)
+ EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ _ eltty)
| SpArr @_ @sdElt sdElt <- sd
, let eltty = typeOf ef
, shty :: STy shty <- tTup (sreplicate ndim tIx)
@@ -1081,7 +1081,7 @@ drev des accumMap sd = \case
(#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
e2)
- EMap _ ef (earr :: Expr _ _ (TArr n a))
+ EMap _ ef (earr :: Expr _ _ _ (TArr n a))
| SpArr sdElt <- sd
, let STArr ndim t1 = typeOf earr
t2 = typeOf ef ->
@@ -1391,7 +1391,7 @@ deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True)
=> (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
-> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
-> Sparse (D2s t) sd
- -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
+ -> Expr NoExt ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
deriv_extremum extremum des accumMap sd e
| at@(STArr (SS n) t@(STScal st)) <- typeOf e
, let at' = STArr n t
@@ -1437,7 +1437,7 @@ drevScoped :: forall a s env sto sd t.
=> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
-> STy a -> Storage s -> Maybe (ValId a)
-> Sparse (D2 t) sd
- -> Expr ValId (a : env) t
+ -> Expr NoExt ValId (a : env) t
-> RetScoped env sto a s sd t
drevScoped des accumMap argty argsto argids sd expr = case argsto of
SMerge
@@ -1496,7 +1496,7 @@ drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False)
-> VarMap Int (D2AcE (Select env sto "accum"))
-> (STy a, Storage s)
-> Sparse (D2 t) dt
- -> Expr ValId (a : env) t
+ -> Expr NoExt ValId (a : env) t
-> (forall provars shbinds tape d2a'.
SList STy provars
-> Subenv (D2E (Select env sto "merge")) (D2E provars)
@@ -1574,7 +1574,7 @@ drevLambda des accumMap (argty, argsto) sd origef k =
prf1 _ _ SDiscr = Refl
-- TODO: proper primal-only transform that doesn't depend on D1 = Id
-drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
+drevPrimal :: Descr env sto -> Expr NoExt x env t -> Ex (D1E env) (D1 t)
drevPrimal des e
| Refl <- d1Identity (typeOf e)
, Refl <- d1eIdentity (descrList des)