From c586e7d2343fa735a9b27e0b1a201dd2cb2bc68e Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 21 Apr 2025 21:56:34 +0200 Subject: Add integer modulo operator --- src/AST.hs | 3 +++ src/AST/Pretty.hs | 1 + src/CHAD.hs | 2 ++ src/Compile.hs | 2 ++ src/ForwardAD/DualNumbers.hs | 3 +++ src/Interpreter.hs | 1 + src/Language.hs | 4 ++++ 7 files changed, 16 insertions(+) (limited to 'src') diff --git a/src/AST.hs b/src/AST.hs index 652d003..9161956 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -127,6 +127,7 @@ data SOp a t where OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) deriving instance Show (SOp a t) opt1 :: SOp a t -> STy a @@ -147,6 +148,7 @@ opt1 = \case OExp t -> STScal t OLog t -> STScal t OIDiv t -> STPair (STScal t) (STScal t) + OMod t -> STPair (STScal t) (STScal t) opt2 :: SOp a t -> STy t opt2 = \case @@ -166,6 +168,7 @@ opt2 = \case OExp t -> STScal t OLog t -> STScal t OIDiv t -> STScal t + OMod t -> STScal t typeOf :: Expr x env t -> STy t typeOf = \case diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index da4f391..19c7cfc 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -343,6 +343,7 @@ operator ORecip{} = (Prefix, "recip") operator OExp{} = (Prefix, "exp") operator OLog{} = (Prefix, "log") operator OIDiv{} = (Infix, "`div`") +operator OMod{} = (Infix, "`mod`") ppSTy :: Int -> STy t -> String ppSTy d ty = ppTy d (unSTy ty) diff --git a/src/CHAD.hs b/src/CHAD.hs index 6ab4cfb..1126fde 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -219,6 +219,7 @@ d1op (ORecip t) e = EOp ext (ORecip t) e d1op (OExp t) e = EOp ext (OExp t) e d1op (OLog t) e = EOp ext (OLog t) e d1op (OIDiv t) e = EOp ext (OIDiv t) e +d1op (OMod t) e = EOp ext (OMod t) e -- | Both primal and dual must be duplicable expressions data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) @@ -244,6 +245,7 @@ d2op op = case op of OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) + OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) where d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) diff --git a/src/Compile.hs b/src/Compile.hs index 2a184f7..6466065 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1271,6 +1271,7 @@ compileOpGeneral op e1 = do OLog STF32 -> unary "logf" OLog STF64 -> unary "log" OIDiv _ -> binary "/" + OMod _ -> binary "%" compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr compileOpPair op e1 e2 = do @@ -1284,6 +1285,7 @@ compileOpPair op e1 e2 = do OAnd -> binary "&&" OOr -> binary "||" OIDiv _ -> binary "/" + OMod _ -> binary "%" _ -> error "compileOpPair: got unary operator" -- | Bool: whether to ensure that the literal itself already has the appropriate type diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 9a95f81..2f94076 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -86,6 +86,9 @@ dop = \case OIDiv t -> scalTyCase t (case t of {}) (EOp ext (OIDiv t)) + OMod t -> scalTyCase t + (case t of {}) + (EOp ext (OMod t)) where add :: ScalIsNumeric t ~ True => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 572f2bd..58d79a5 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -184,6 +184,7 @@ interpretOp op arg = case op of OExp st -> floatingIsFractional st $ exp arg OLog st -> floatingIsFractional st $ log arg OIDiv st -> integralIsIntegral st $ uncurry quot arg + OMod st -> integralIsIntegral st $ uncurry rem arg where styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r styIsEq STI32 = id diff --git a/src/Language.hs b/src/Language.hs index a66b8b6..cf7cc4c 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -204,6 +204,10 @@ or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TB or_ = oper2 OOr infixr 2 `or_` +mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) +mod_ = oper2 (OMod knownScalTy) +infixl 7 `mod_` + -- | The first alternative is the True case; the second is the False case. if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) -- cgit v1.2.3-70-g09d2