aboutsummaryrefslogtreecommitdiff
path: root/src/Language/AST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Language/AST.hs')
-rw-r--r--src/Language/AST.hs20
1 files changed, 19 insertions, 1 deletions
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 387915b..be98ccf 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM
import Array
import AST
+import AST.Sparse.Types
import CHAD.Types
import Data
@@ -42,6 +43,9 @@ data NExpr env t where
NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b)
NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b)
NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c
+ NENothing :: STy t -> NExpr env (TMaybe t)
+ NEJust :: NExpr env t -> NExpr env (TMaybe t)
+ NEMaybe :: NExpr env b -> Var name t -> NExpr ('(name, t) : env) b -> NExpr env (TMaybe t) -> NExpr env b
-- array operations
NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
@@ -68,6 +72,13 @@ data NExpr env t where
-> NExpr env a -> NExpr env b
-> NExpr env t
+ -- fake halfway checkpointing
+ NERecompute :: NExpr env t -> NExpr env t
+
+ -- accumulation effect on monoids
+ NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)
+ NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
+
-- partiality
NEError :: STy a -> String -> NExpr env a
@@ -182,10 +193,13 @@ fromNamedExpr val = \case
NEInl t e -> EInl ext t (go e)
NEInr t e -> EInr ext t (go e)
NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b)
+ NENothing t -> ENothing ext t
+ NEJust e -> EJust ext (go e)
+ NEMaybe a n b c -> EMaybe ext (go a) (lambda val n b) (go c)
NEConstArr n t x -> EConstArr ext n t x
NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
- NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c)
+ NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
NESum1Inner e -> ESum1Inner ext (go e)
NEUnit e -> EUnit ext (go e)
NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
@@ -205,6 +219,10 @@ fromNamedExpr val = \case
(fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
(fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
(go e1) (go e2)
+ NERecompute e -> ERecompute ext (go e)
+
+ NEWith t a n b -> EWith ext t (go a) (lambda val n b)
+ NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c)
NEError t s -> EError ext t s