diff options
Diffstat (limited to 'src/Language/AST.hs')
-rw-r--r-- | src/Language/AST.hs | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 387915b..be98ccf 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM import Array import AST +import AST.Sparse.Types import CHAD.Types import Data @@ -42,6 +43,9 @@ data NExpr env t where NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b) NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b) NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c + NENothing :: STy t -> NExpr env (TMaybe t) + NEJust :: NExpr env t -> NExpr env (TMaybe t) + NEMaybe :: NExpr env b -> Var name t -> NExpr ('(name, t) : env) b -> NExpr env (TMaybe t) -> NExpr env b -- array operations NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) @@ -68,6 +72,13 @@ data NExpr env t where -> NExpr env a -> NExpr env b -> NExpr env t + -- fake halfway checkpointing + NERecompute :: NExpr env t -> NExpr env t + + -- accumulation effect on monoids + NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil + -- partiality NEError :: STy a -> String -> NExpr env a @@ -182,10 +193,13 @@ fromNamedExpr val = \case NEInl t e -> EInl ext t (go e) NEInr t e -> EInr ext t (go e) NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b) + NENothing t -> ENothing ext t + NEJust e -> EJust ext (go e) + NEMaybe a n b c -> EMaybe ext (go a) (lambda val n b) (go c) NEConstArr n t x -> EConstArr ext n t x NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c) + NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) @@ -205,6 +219,10 @@ fromNamedExpr val = \case (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) (go e1) (go e2) + NERecompute e -> ERecompute ext (go e) + + NEWith t a n b -> EWith ext t (go a) (lambda val n b) + NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c) NEError t s -> EError ext t s |