diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-04 21:33:56 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-04 21:45:50 +0100 |
commit | d751deedfdc2ba5fbeb72ede5754587a1f677835 (patch) | |
tree | 7bbb7ea6ac1482599a597feee1401d2b781b4971 | |
parent | bacd70ca6ba028e935bb512aeee713943901acdd (diff) |
Compile: Fix right-precedence of (*)
-rw-r--r-- | src/Compile.hs | 4 | ||||
-rw-r--r-- | src/Interpreter.hs | 2 | ||||
-rw-r--r-- | src/Language.hs | 3 | ||||
-rw-r--r-- | test/Main.hs | 26 |
4 files changed, 29 insertions, 6 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 7a846d4..0e6eee7 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -170,9 +170,9 @@ printCExpr d = \case ,(">", (5, (6, 6))) ,("<=", (5, (6, 6))) ,(">=", (5, (6, 6))) - ,("+", (6, (6, 6))) + ,("+", (6, (6, 7))) ,("-", (6, (6, 7))) - ,("*", (7, (7, 7))) + ,("*", (7, (7, 8))) ,("/", (7, (7, 8))) ,("%", (7, (7, 8)))] diff --git a/src/Interpreter.hs b/src/Interpreter.hs index dd558fe..d80a76e 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -173,7 +173,7 @@ interpretOp op arg = case op of ORecip st -> floatingIsFractional st $ recip arg OExp st -> floatingIsFractional st $ exp arg OLog st -> floatingIsFractional st $ log arg - OIDiv st -> integralIsIntegral st $ uncurry div arg + OIDiv st -> integralIsIntegral st $ uncurry quot arg where styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r styIsEq STI32 = id diff --git a/src/Language.hs b/src/Language.hs index 70cc4f9..810a889 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -182,9 +182,11 @@ not_ = oper ONot and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) and_ = oper2 OAnd +infixr 3 `and_` or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) or_ = oper2 OOr +infixr 2 `or_` -- | The first alternative is the True case; the second is the False case. if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t @@ -198,3 +200,4 @@ toFloat_ = oper OToFl64 idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t) idiv = oper2 (OIDiv knownScalTy) +infixl 7 `idiv` diff --git a/test/Main.hs b/test/Main.hs index 83eaa83..b3a0795 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -66,9 +66,12 @@ gradientByCHAD' simplIters env term input = gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) gradientByForward env term input = drevByFwd env term input 1.0 +closeIsh' :: Double -> Double -> Double -> Bool +closeIsh' h a b = + abs (a - b) < h || (let scale = min (abs a) (abs b) in scale > 10*h && abs (a - b) / scale < h) + closeIsh :: Double -> Double -> Bool -closeIsh a b = - abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5) +closeIsh = closeIsh' 1e-5 data a :$ b = a :$ b deriving (Show) ; infixl :$ @@ -201,7 +204,7 @@ adTestGen name expr envGenerator = let outPrimalI = interpretOpen False input expr outPrimal <- liftIO $ getprimalfun >>= ($ input) - diff outPrimal closeIsh outPrimalI + diff outPrimal (closeIsh' 1e-8) outPrimalI let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0 (outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS @@ -260,6 +263,23 @@ tests = testGroup "AD" ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x + ,adTest "operators" $ fromNamed $ lambda #x $ lambda #y $ body $ + let_ #i (round_ #x) $ + let_ #j (round_ #y) $ + let_ #a1 (#x + #y) $ + let_ #a2 (#x - #y) $ + let_ #a3 (#x * #y) $ + let_ #a4 (#x / (#y * #y + 1)) $ + let_ #b1 (#i + #j) $ + let_ #b2 (#i - #j) $ + let_ #b3 (#i * #j) $ + let_ #b4 (#i `idiv` (#j * #j + 1)) $ + #a1 + #a2 + #a3 + #a4 + + toFloat_ (#b1 + #b2 + #b3 + #b4) + + ,adTest "order-of-operations" $ fromNamed $ body $ + toFloat_ (3 * (3 `idiv` 2)) -- Compile had a pretty-printing bug at some point + ,adTest "sum-vec" $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x) ,adTest "sum-replicate" $ fromNamed $ lambda #x $ body $ |