summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST.hs3
-rw-r--r--src/AST/Pretty.hs1
-rw-r--r--src/CHAD.hs2
-rw-r--r--src/Compile.hs2
-rw-r--r--src/ForwardAD/DualNumbers.hs3
-rw-r--r--src/Interpreter.hs1
-rw-r--r--src/Language.hs4
7 files changed, 16 insertions, 0 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 652d003..9161956 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -127,6 +127,7 @@ data SOp a t where
OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
deriving instance Show (SOp a t)
opt1 :: SOp a t -> STy a
@@ -147,6 +148,7 @@ opt1 = \case
OExp t -> STScal t
OLog t -> STScal t
OIDiv t -> STPair (STScal t) (STScal t)
+ OMod t -> STPair (STScal t) (STScal t)
opt2 :: SOp a t -> STy t
opt2 = \case
@@ -166,6 +168,7 @@ opt2 = \case
OExp t -> STScal t
OLog t -> STScal t
OIDiv t -> STScal t
+ OMod t -> STScal t
typeOf :: Expr x env t -> STy t
typeOf = \case
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index da4f391..19c7cfc 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -343,6 +343,7 @@ operator ORecip{} = (Prefix, "recip")
operator OExp{} = (Prefix, "exp")
operator OLog{} = (Prefix, "log")
operator OIDiv{} = (Infix, "`div`")
+operator OMod{} = (Infix, "`mod`")
ppSTy :: Int -> STy t -> String
ppSTy d ty = ppTy d (unSTy ty)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 6ab4cfb..1126fde 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -219,6 +219,7 @@ d1op (ORecip t) e = EOp ext (ORecip t) e
d1op (OExp t) e = EOp ext (OExp t) e
d1op (OLog t) e = EOp ext (OLog t) e
d1op (OIDiv t) e = EOp ext (OIDiv t) e
+d1op (OMod t) e = EOp ext (OMod t) e
-- | Both primal and dual must be duplicable expressions
data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
@@ -244,6 +245,7 @@ d2op op = case op of
OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
where
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
diff --git a/src/Compile.hs b/src/Compile.hs
index 2a184f7..6466065 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1271,6 +1271,7 @@ compileOpGeneral op e1 = do
OLog STF32 -> unary "logf"
OLog STF64 -> unary "log"
OIDiv _ -> binary "/"
+ OMod _ -> binary "%"
compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr
compileOpPair op e1 e2 = do
@@ -1284,6 +1285,7 @@ compileOpPair op e1 e2 = do
OAnd -> binary "&&"
OOr -> binary "||"
OIDiv _ -> binary "/"
+ OMod _ -> binary "%"
_ -> error "compileOpPair: got unary operator"
-- | Bool: whether to ensure that the literal itself already has the appropriate type
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 9a95f81..2f94076 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -86,6 +86,9 @@ dop = \case
OIDiv t -> scalTyCase t
(case t of {})
(EOp ext (OIDiv t))
+ OMod t -> scalTyCase t
+ (case t of {})
+ (EOp ext (OMod t))
where
add :: ScalIsNumeric t ~ True
=> SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 572f2bd..58d79a5 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -184,6 +184,7 @@ interpretOp op arg = case op of
OExp st -> floatingIsFractional st $ exp arg
OLog st -> floatingIsFractional st $ log arg
OIDiv st -> integralIsIntegral st $ uncurry quot arg
+ OMod st -> integralIsIntegral st $ uncurry rem arg
where
styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r
styIsEq STI32 = id
diff --git a/src/Language.hs b/src/Language.hs
index a66b8b6..cf7cc4c 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -204,6 +204,10 @@ or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TB
or_ = oper2 OOr
infixr 2 `or_`
+mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a)
+mod_ = oper2 (OMod knownScalTy)
+infixl 7 `mod_`
+
-- | The first alternative is the True case; the second is the False case.
if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b)