summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-27 00:01:27 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-27 00:01:27 +0100
commitadbe3c3c75ecd1a0a6f38165329694f309d6891c (patch)
tree55e8f158931810e21b33d5f668e533796f93bca0
parent75141f1c1f97fef563df2be6e512e568f922cb45 (diff)
test: Some more Compile tests (still passing, but code still broken)
-rw-r--r--test/Main.hs33
1 files changed, 28 insertions, 5 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 933629c..cc17b5d 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -219,12 +219,18 @@ instance Show (TypedValue t) where
showsPrec d (TypedValue t x) = showValue d t x
compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree
-compileTest name expr =
+compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr
+
+compileTestTp :: KnownEnv env => TestName -> TemplateE env -> Ex env t -> TestTree
+compileTestTp name tmpl expr = compileTestGen name expr (evalStateT (genEnv knownEnv tmpl) mempty)
+
+compileTestGen :: KnownEnv env => TestName -> Ex env t -> Gen (SList Value env) -> TestTree
+compileTestGen name expr envGenerator =
let env = knownEnv
t = typeOf expr
in withCompiled env expr $ \fun ->
testProperty name $ property $ do
- input <- forAllWith (showEnv env) (evalStateT (genEnv env (emptyTemplateE env)) mempty)
+ input <- forAllWith (showEnv env) envGenerator
let resI = interpretOpen False input expr
resC <- liftIO $ fun input
let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y
@@ -383,9 +389,26 @@ term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec
tests_Compile :: TestTree
tests_Compile = testGroup "Compile"
- [compileTest "accum f64" $ fromNamed $ lambda #x $ body $
- with @(TScal TF64) 0.0 $ #ac :->
- accum SAPHere nil #x #ac]
+ [compileTest "accum f64" $ fromNamed $ lambda #b $ lambda #x $ body $
+ with @R 0.0 $ #ac :->
+ if_ #b (accum SAPHere nil #x #ac)
+ nil
+
+ ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $
+ with @(TPair R R) nothing $ #ac :->
+ let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $
+ let_ #_ (accum SAPHere nil #x #ac) $
+ let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $
+ nil
+
+ ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $
+ let_ #len (snd_ (shape #x)) $
+ with @(TArr N1 R) nothing $ #ac :->
+ let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac)
+ nil) $
+ let_ #_ (accum SAPHere nil (just #x) #ac) $
+ nil
+ ]
tests_AD :: TestTree
tests_AD = testGroup "AD"