diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-27 00:01:27 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-27 00:01:27 +0100 |
commit | adbe3c3c75ecd1a0a6f38165329694f309d6891c (patch) | |
tree | 55e8f158931810e21b33d5f668e533796f93bca0 /test/Main.hs | |
parent | 75141f1c1f97fef563df2be6e512e568f922cb45 (diff) |
test: Some more Compile tests (still passing, but code still broken)
Diffstat (limited to 'test/Main.hs')
-rw-r--r-- | test/Main.hs | 33 |
1 files changed, 28 insertions, 5 deletions
diff --git a/test/Main.hs b/test/Main.hs index 933629c..cc17b5d 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -219,12 +219,18 @@ instance Show (TypedValue t) where showsPrec d (TypedValue t x) = showValue d t x compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree -compileTest name expr = +compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr + +compileTestTp :: KnownEnv env => TestName -> TemplateE env -> Ex env t -> TestTree +compileTestTp name tmpl expr = compileTestGen name expr (evalStateT (genEnv knownEnv tmpl) mempty) + +compileTestGen :: KnownEnv env => TestName -> Ex env t -> Gen (SList Value env) -> TestTree +compileTestGen name expr envGenerator = let env = knownEnv t = typeOf expr in withCompiled env expr $ \fun -> testProperty name $ property $ do - input <- forAllWith (showEnv env) (evalStateT (genEnv env (emptyTemplateE env)) mempty) + input <- forAllWith (showEnv env) envGenerator let resI = interpretOpen False input expr resC <- liftIO $ fun input let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y @@ -383,9 +389,26 @@ term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec tests_Compile :: TestTree tests_Compile = testGroup "Compile" - [compileTest "accum f64" $ fromNamed $ lambda #x $ body $ - with @(TScal TF64) 0.0 $ #ac :-> - accum SAPHere nil #x #ac] + [compileTest "accum f64" $ fromNamed $ lambda #b $ lambda #x $ body $ + with @R 0.0 $ #ac :-> + if_ #b (accum SAPHere nil #x #ac) + nil + + ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ + with @(TPair R R) nothing $ #ac :-> + let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ + let_ #_ (accum SAPHere nil #x #ac) $ + let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ + nil + + ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $ + let_ #len (snd_ (shape #x)) $ + with @(TArr N1 R) nothing $ #ac :-> + let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac) + nil) $ + let_ #_ (accum SAPHere nil (just #x) #ac) $ + nil + ] tests_AD :: TestTree tests_AD = testGroup "AD" |