summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-04 21:33:56 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-04 21:45:50 +0100
commitd751deedfdc2ba5fbeb72ede5754587a1f677835 (patch)
tree7bbb7ea6ac1482599a597feee1401d2b781b4971 /test
parentbacd70ca6ba028e935bb512aeee713943901acdd (diff)
Compile: Fix right-precedence of (*)
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs26
1 files changed, 23 insertions, 3 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 83eaa83..b3a0795 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -66,9 +66,12 @@ gradientByCHAD' simplIters env term input =
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0
+closeIsh' :: Double -> Double -> Double -> Bool
+closeIsh' h a b =
+ abs (a - b) < h || (let scale = min (abs a) (abs b) in scale > 10*h && abs (a - b) / scale < h)
+
closeIsh :: Double -> Double -> Bool
-closeIsh a b =
- abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5)
+closeIsh = closeIsh' 1e-5
data a :$ b = a :$ b deriving (Show) ; infixl :$
@@ -201,7 +204,7 @@ adTestGen name expr envGenerator =
let outPrimalI = interpretOpen False input expr
outPrimal <- liftIO $ getprimalfun >>= ($ input)
- diff outPrimal closeIsh outPrimalI
+ diff outPrimal (closeIsh' 1e-8) outPrimalI
let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
(outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS
@@ -260,6 +263,23 @@ tests = testGroup "AD"
,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x
+ ,adTest "operators" $ fromNamed $ lambda #x $ lambda #y $ body $
+ let_ #i (round_ #x) $
+ let_ #j (round_ #y) $
+ let_ #a1 (#x + #y) $
+ let_ #a2 (#x - #y) $
+ let_ #a3 (#x * #y) $
+ let_ #a4 (#x / (#y * #y + 1)) $
+ let_ #b1 (#i + #j) $
+ let_ #b2 (#i - #j) $
+ let_ #b3 (#i * #j) $
+ let_ #b4 (#i `idiv` (#j * #j + 1)) $
+ #a1 + #a2 + #a3 + #a4 +
+ toFloat_ (#b1 + #b2 + #b3 + #b4)
+
+ ,adTest "order-of-operations" $ fromNamed $ body $
+ toFloat_ (3 * (3 `idiv` 2)) -- Compile had a pretty-printing bug at some point
+
,adTest "sum-vec" $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)
,adTest "sum-replicate" $ fromNamed $ lambda #x $ body $