diff options
Diffstat (limited to 'src/Language/AST.hs')
-rw-r--r-- | src/Language/AST.hs | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 84544f8..be98ccf 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM import Array import AST +import AST.Sparse.Types import CHAD.Types import Data @@ -71,9 +72,12 @@ data NExpr env t where -> NExpr env a -> NExpr env b -> NExpr env t + -- fake halfway checkpointing + NERecompute :: NExpr env t -> NExpr env t + -- accumulation effect on monoids - NEWith :: STy t -> NExpr env (D2 t) -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a (D2 t)) - NEAccum :: STy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env (D2 a) -> NExpr env (TAccum t) -> NExpr env TNil + NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil -- partiality NEError :: STy a -> String -> NExpr env a @@ -215,9 +219,10 @@ fromNamedExpr val = \case (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) (go e1) (go e2) + NERecompute e -> ERecompute ext (go e) NEWith t a n b -> EWith ext t (go a) (lambda val n b) - NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c) + NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c) NEError t s -> EError ext t s |