diff options
| author | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-21 21:56:34 +0200 | 
|---|---|---|
| committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-21 21:56:34 +0200 | 
| commit | c586e7d2343fa735a9b27e0b1a201dd2cb2bc68e (patch) | |
| tree | 9315228980b2c5983c785889b815aaab46534052 | |
| parent | 0dc5c31b023ee7d569bbc0df7615b2bf55ba01f5 (diff) | |
Add integer modulo operator
| -rw-r--r-- | src/AST.hs | 3 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 1 | ||||
| -rw-r--r-- | src/CHAD.hs | 2 | ||||
| -rw-r--r-- | src/Compile.hs | 2 | ||||
| -rw-r--r-- | src/ForwardAD/DualNumbers.hs | 3 | ||||
| -rw-r--r-- | src/Interpreter.hs | 1 | ||||
| -rw-r--r-- | src/Language.hs | 4 | 
7 files changed, 16 insertions, 0 deletions
| @@ -127,6 +127,7 @@ data SOp a t where    OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)    OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)    OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) +  OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)  deriving instance Show (SOp a t)  opt1 :: SOp a t -> STy a @@ -147,6 +148,7 @@ opt1 = \case    OExp t -> STScal t    OLog t -> STScal t    OIDiv t -> STPair (STScal t) (STScal t) +  OMod t -> STPair (STScal t) (STScal t)  opt2 :: SOp a t -> STy t  opt2 = \case @@ -166,6 +168,7 @@ opt2 = \case    OExp t -> STScal t    OLog t -> STScal t    OIDiv t -> STScal t +  OMod t -> STScal t  typeOf :: Expr x env t -> STy t  typeOf = \case diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index da4f391..19c7cfc 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -343,6 +343,7 @@ operator ORecip{} = (Prefix, "recip")  operator OExp{} = (Prefix, "exp")  operator OLog{} = (Prefix, "log")  operator OIDiv{} = (Infix, "`div`") +operator OMod{} = (Infix, "`mod`")  ppSTy :: Int -> STy t -> String  ppSTy d ty = ppTy d (unSTy ty) diff --git a/src/CHAD.hs b/src/CHAD.hs index 6ab4cfb..1126fde 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -219,6 +219,7 @@ d1op (ORecip t) e = EOp ext (ORecip t) e  d1op (OExp t) e = EOp ext (OExp t) e  d1op (OLog t) e = EOp ext (OLog t) e  d1op (OIDiv t) e = EOp ext (OIDiv t) e +d1op (OMod t) e = EOp ext (OMod t) e  -- | Both primal and dual must be duplicable expressions  data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) @@ -244,6 +245,7 @@ d2op op = case op of    OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)    OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)    OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) +  OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)    where      d2opUnArrangeInt :: SScalTy a                       -> (D2s a ~ TScal a => D2Op (TScal a) t) diff --git a/src/Compile.hs b/src/Compile.hs index 2a184f7..6466065 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1271,6 +1271,7 @@ compileOpGeneral op e1 = do      OLog STF32 -> unary "logf"      OLog STF64 -> unary "log"      OIDiv _ -> binary "/" +    OMod _ -> binary "%"  compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr  compileOpPair op e1 e2 = do @@ -1284,6 +1285,7 @@ compileOpPair op e1 e2 = do      OAnd -> binary "&&"      OOr -> binary "||"      OIDiv _ -> binary "/" +    OMod _ -> binary "%"      _ -> error "compileOpPair: got unary operator"  -- | Bool: whether to ensure that the literal itself already has the appropriate type diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 9a95f81..2f94076 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -86,6 +86,9 @@ dop = \case    OIDiv t -> scalTyCase t      (case t of {})      (EOp ext (OIDiv t)) +  OMod t -> scalTyCase t +    (case t of {}) +    (EOp ext (OMod t))    where      add :: ScalIsNumeric t ~ True          => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 572f2bd..58d79a5 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -184,6 +184,7 @@ interpretOp op arg = case op of    OExp st -> floatingIsFractional st $ exp arg    OLog st -> floatingIsFractional st $ log arg    OIDiv st -> integralIsIntegral st $ uncurry quot arg +  OMod st -> integralIsIntegral st $ uncurry rem arg    where      styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r      styIsEq STI32 = id diff --git a/src/Language.hs b/src/Language.hs index a66b8b6..cf7cc4c 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -204,6 +204,10 @@ or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TB  or_ = oper2 OOr  infixr 2 `or_` +mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) +mod_ = oper2 (OMod knownScalTy) +infixl 7 `mod_` +  -- | The first alternative is the True case; the second is the False case.  if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t  if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) | 
