diff options
Diffstat (limited to 'src/Language.hs')
-rw-r--r-- | src/Language.hs | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/src/Language.hs b/src/Language.hs index a66b8b6..4e6d604 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -17,6 +17,7 @@ module Language ( import Array import AST +import AST.Sparse.Types import AST.Types import CHAD.Types import Data @@ -149,6 +150,9 @@ infixl 9 ! shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) shape = NEShape +length_ :: NExpr env (TArr N1 t) -> NExpr env TIx +length_ e = snd_ (shape e) + oper :: SOp a t -> NExpr env a -> NExpr env t oper = NEOp @@ -166,11 +170,17 @@ custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 -with :: forall t a env acname. KnownTy t => NExpr env (D2 t) -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a (D2 t)) -with a (n :-> b) = NEWith (knownTy @t) a n b +recompute :: NExpr env a -> NExpr env a +recompute = NERecompute + +with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) +with a (n :-> b) = NEWith (knownMTy @t) a n b -accum :: KnownTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env (D2 a) -> NExpr env (TAccum t) -> NExpr env TNil -accum p a b c = NEAccum knownTy p a b c +accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil +accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c + +accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil +accumS p a sp b c = NEAccum knownMTy p a sp b c (.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) @@ -204,6 +214,10 @@ or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TB or_ = oper2 OOr infixr 2 `or_` +mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) +mod_ = oper2 (OMod knownScalTy) +infixl 7 `mod_` + -- | The first alternative is the True case; the second is the False case. if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) |