diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-27 00:01:27 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-27 00:01:27 +0100 | 
| commit | adbe3c3c75ecd1a0a6f38165329694f309d6891c (patch) | |
| tree | 55e8f158931810e21b33d5f668e533796f93bca0 /test | |
| parent | 75141f1c1f97fef563df2be6e512e568f922cb45 (diff) | |
test: Some more Compile tests (still passing, but code still broken)
Diffstat (limited to 'test')
| -rw-r--r-- | test/Main.hs | 33 | 
1 files changed, 28 insertions, 5 deletions
| diff --git a/test/Main.hs b/test/Main.hs index 933629c..cc17b5d 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -219,12 +219,18 @@ instance Show (TypedValue t) where    showsPrec d (TypedValue t x) = showValue d t x  compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree -compileTest name expr = +compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr + +compileTestTp :: KnownEnv env => TestName -> TemplateE env -> Ex env t -> TestTree +compileTestTp name tmpl expr = compileTestGen name expr (evalStateT (genEnv knownEnv tmpl) mempty) + +compileTestGen :: KnownEnv env => TestName -> Ex env t -> Gen (SList Value env) -> TestTree +compileTestGen name expr envGenerator =    let env = knownEnv        t = typeOf expr    in withCompiled env expr $ \fun ->       testProperty name $ property $ do -       input <- forAllWith (showEnv env) (evalStateT (genEnv env (emptyTemplateE env)) mempty) +       input <- forAllWith (showEnv env) envGenerator         let resI = interpretOpen False input expr         resC <- liftIO $ fun input         let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y @@ -383,9 +389,26 @@ term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec  tests_Compile :: TestTree  tests_Compile = testGroup "Compile" -  [compileTest "accum f64" $ fromNamed $ lambda #x $ body $ -      with @(TScal TF64) 0.0 $ #ac :-> -        accum SAPHere nil #x #ac] +  [compileTest "accum f64" $ fromNamed $ lambda #b $ lambda #x $ body $ +      with @R 0.0 $ #ac :-> +        if_ #b (accum SAPHere nil #x #ac) +               nil + +  ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ +      with @(TPair R R) nothing $ #ac :-> +        let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ +        let_ #_ (accum SAPHere nil #x #ac) $ +        let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ +          nil + +  ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $ +      let_ #len (snd_ (shape #x)) $ +      with @(TArr N1 R) nothing $ #ac :-> +        let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac) +                        nil) $ +        let_ #_ (accum SAPHere nil (just #x) #ac) $ +          nil +  ]  tests_AD :: TestTree  tests_AD = testGroup "AD" | 
