summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-04 21:33:56 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-04 21:45:50 +0100
commitd751deedfdc2ba5fbeb72ede5754587a1f677835 (patch)
tree7bbb7ea6ac1482599a597feee1401d2b781b4971
parentbacd70ca6ba028e935bb512aeee713943901acdd (diff)
Compile: Fix right-precedence of (*)
-rw-r--r--src/Compile.hs4
-rw-r--r--src/Interpreter.hs2
-rw-r--r--src/Language.hs3
-rw-r--r--test/Main.hs26
4 files changed, 29 insertions, 6 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 7a846d4..0e6eee7 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -170,9 +170,9 @@ printCExpr d = \case
,(">", (5, (6, 6)))
,("<=", (5, (6, 6)))
,(">=", (5, (6, 6)))
- ,("+", (6, (6, 6)))
+ ,("+", (6, (6, 7)))
,("-", (6, (6, 7)))
- ,("*", (7, (7, 7)))
+ ,("*", (7, (7, 8)))
,("/", (7, (7, 8)))
,("%", (7, (7, 8)))]
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index dd558fe..d80a76e 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -173,7 +173,7 @@ interpretOp op arg = case op of
ORecip st -> floatingIsFractional st $ recip arg
OExp st -> floatingIsFractional st $ exp arg
OLog st -> floatingIsFractional st $ log arg
- OIDiv st -> integralIsIntegral st $ uncurry div arg
+ OIDiv st -> integralIsIntegral st $ uncurry quot arg
where
styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r
styIsEq STI32 = id
diff --git a/src/Language.hs b/src/Language.hs
index 70cc4f9..810a889 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -182,9 +182,11 @@ not_ = oper ONot
and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
and_ = oper2 OAnd
+infixr 3 `and_`
or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
or_ = oper2 OOr
+infixr 2 `or_`
-- | The first alternative is the True case; the second is the False case.
if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
@@ -198,3 +200,4 @@ toFloat_ = oper OToFl64
idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t)
idiv = oper2 (OIDiv knownScalTy)
+infixl 7 `idiv`
diff --git a/test/Main.hs b/test/Main.hs
index 83eaa83..b3a0795 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -66,9 +66,12 @@ gradientByCHAD' simplIters env term input =
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0
+closeIsh' :: Double -> Double -> Double -> Bool
+closeIsh' h a b =
+ abs (a - b) < h || (let scale = min (abs a) (abs b) in scale > 10*h && abs (a - b) / scale < h)
+
closeIsh :: Double -> Double -> Bool
-closeIsh a b =
- abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5)
+closeIsh = closeIsh' 1e-5
data a :$ b = a :$ b deriving (Show) ; infixl :$
@@ -201,7 +204,7 @@ adTestGen name expr envGenerator =
let outPrimalI = interpretOpen False input expr
outPrimal <- liftIO $ getprimalfun >>= ($ input)
- diff outPrimal closeIsh outPrimalI
+ diff outPrimal (closeIsh' 1e-8) outPrimalI
let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
(outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS
@@ -260,6 +263,23 @@ tests = testGroup "AD"
,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x
+ ,adTest "operators" $ fromNamed $ lambda #x $ lambda #y $ body $
+ let_ #i (round_ #x) $
+ let_ #j (round_ #y) $
+ let_ #a1 (#x + #y) $
+ let_ #a2 (#x - #y) $
+ let_ #a3 (#x * #y) $
+ let_ #a4 (#x / (#y * #y + 1)) $
+ let_ #b1 (#i + #j) $
+ let_ #b2 (#i - #j) $
+ let_ #b3 (#i * #j) $
+ let_ #b4 (#i `idiv` (#j * #j + 1)) $
+ #a1 + #a2 + #a3 + #a4 +
+ toFloat_ (#b1 + #b2 + #b3 + #b4)
+
+ ,adTest "order-of-operations" $ fromNamed $ body $
+ toFloat_ (3 * (3 `idiv` 2)) -- Compile had a pretty-printing bug at some point
+
,adTest "sum-vec" $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)
,adTest "sum-replicate" $ fromNamed $ lambda #x $ body $