diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 23:47:59 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 23:47:59 +0100 | 
| commit | 6e85d5b2aee0cf2c089538e74261f1d88d6b1b71 (patch) | |
| tree | 1ae5699e3b420e7a61203dfb697ad17eba38f683 /test | |
| parent | 16a836d078caefc3526031c084e2527cba0da3a8 (diff) | |
test: Start of a list of compile tests
Diffstat (limited to 'test')
| -rw-r--r-- | test/Main.hs | 66 | 
1 files changed, 57 insertions, 9 deletions
| diff --git a/test/Main.hs b/test/Main.hs index 19bd3e6..014ad43 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -93,6 +93,28 @@ closeIsh' h a b =  closeIsh :: Double -> Double -> Bool  closeIsh = closeIsh' 1e-5 +closeIshT' :: Double -> STy t -> Rep t -> Rep t -> Bool +closeIshT' _ STNil () () = True +closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h b y y' +closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x' +closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x' +closeIshT' _ STEither{} _ _ = False +closeIshT' _ (STMaybe _) Nothing Nothing = True +closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x' +closeIshT' _ STMaybe{} _ _ = False +closeIshT' h (STArr _ a) arr1 arr2 = +  arrayShape arr1 == arrayShape arr2 && +    and (zipWith (closeIshT' h a) (arrayToList arr1) (arrayToList arr2)) +closeIshT' _ (STScal STI32) i j = i == j +closeIshT' _ (STScal STI64) i j = i == j +closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y) +closeIshT' h (STScal STF64) x y = closeIsh' h x y +closeIshT' _ (STScal STBool) x y = x == y +closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators" + +closeIshT :: STy t -> Rep t -> Rep t -> Bool +closeIshT = closeIshT' 1e-5 +  data a :$ b = a :$ b deriving (Show) ; infixl :$  -- An empty name means "no restrictions". @@ -189,6 +211,22 @@ genEnv SNil () = return SNil  genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil  genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl +data TypedValue t = TypedValue (STy t) (Rep t) +instance Show (TypedValue t) where +  showsPrec d (TypedValue t x) = showValue d t x + +compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree +compileTest name expr = +  let env = knownEnv +      t = typeOf expr +  in withCompiled env expr $ \fun -> +     testProperty name $ property $ do +       input <- forAllWith (showEnv env) (evalStateT (genEnv env (emptyTemplateE env)) mempty) +       let resI = interpretOpen False input expr +       resC <- liftIO $ fun input +       let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y +       diff (TypedValue t resI) cmp (TypedValue t resC) +  adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree  adTest name = adTestCon name (const True) @@ -247,16 +285,16 @@ adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)                -> (SList Value env -> IO Double)                -> TestTree  adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = +  let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr +      dtermChadS = simplifyFix dtermChad0 +      dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS +      dtermSChadS = simplifyFix dtermSChad0 +  in    withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> -  withCompiled env (simplifyFix (unMonoid (simplifyFix (ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS)))) $ \dcompSChadS -> +  withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS ->      testProperty "chad" $ property $ do        annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) -      let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr -          dtermChadS = simplifyFix dtermChad0 -      let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS -          dtermSChadS = simplifyFix dtermSChad0 -        -- pack Text for less GC pressure (these values are retained for some reason)        diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))        diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) @@ -286,11 +324,13 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =        -- annotate (ppExpr knownEnv expr)        -- annotate (ppExpr env dtermChad0)        -- annotate (ppExpr env dtermChadS) +      annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))        diff outChad0      closeIsh outPrimal        diff outChadS      closeIsh outPrimal        diff outSChad0     closeIsh outPrimal        diff outSChadS     closeIsh outPrimal        diff outCompSChadS closeIsh outPrimal +      -- TODO: use closeIshT        let closeIshList x y = and (zipWith closeIsh x y)        diff scChad       closeIshList scFwd        diff scChadS      closeIshList scFwd @@ -338,8 +378,14 @@ term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec          idx0 (sum1i (build1 #wid $ #j :->                         #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) -tests :: TestTree -tests = testGroup "AD" +tests_Compile :: TestTree +tests_Compile = testGroup "Compile" +  [compileTest "accum f64" $ fromNamed $ lambda #x $ body $ +      with @(TScal TF64) 0.0 $ #ac :-> +        accum SAPHere nil #x #ac] + +tests_AD :: TestTree +tests_AD = testGroup "AD"    [adTest "id" $ fromNamed $ lambda #x $ body $ #x    ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x @@ -456,4 +502,6 @@ tests = testGroup "AD"        return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)  main :: IO () -main = defaultMain tests +main = defaultMain $ testGroup "All" +  [tests_Compile +  ,tests_AD] | 
