aboutsummaryrefslogtreecommitdiff
path: root/src/Language.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Language.hs')
-rw-r--r--src/Language.hs22
1 files changed, 18 insertions, 4 deletions
diff --git a/src/Language.hs b/src/Language.hs
index a66b8b6..4e6d604 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -17,6 +17,7 @@ module Language (
import Array
import AST
+import AST.Sparse.Types
import AST.Types
import CHAD.Types
import Data
@@ -149,6 +150,9 @@ infixl 9 !
shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
shape = NEShape
+length_ :: NExpr env (TArr N1 t) -> NExpr env TIx
+length_ e = snd_ (shape e)
+
oper :: SOp a t -> NExpr env a -> NExpr env t
oper = NEOp
@@ -166,11 +170,17 @@ custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =
NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2
-with :: forall t a env acname. KnownTy t => NExpr env (D2 t) -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a (D2 t))
-with a (n :-> b) = NEWith (knownTy @t) a n b
+recompute :: NExpr env a -> NExpr env a
+recompute = NERecompute
+
+with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t)
+with a (n :-> b) = NEWith (knownMTy @t) a n b
-accum :: KnownTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env (D2 a) -> NExpr env (TAccum t) -> NExpr env TNil
-accum p a b c = NEAccum knownTy p a b c
+accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
+accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c
+
+accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
+accumS p a sp b c = NEAccum knownMTy p a sp b c
(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
@@ -204,6 +214,10 @@ or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TB
or_ = oper2 OOr
infixr 2 `or_`
+mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a)
+mod_ = oper2 (OMod knownScalTy)
+infixl 7 `mod_`
+
-- | The first alternative is the True case; the second is the False case.
if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b)