diff options
Diffstat (limited to 'test/Main.hs')
-rw-r--r-- | test/Main.hs | 187 |
1 files changed, 139 insertions, 48 deletions
diff --git a/test/Main.hs b/test/Main.hs index 7dbafab..20b4ef0 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -41,6 +41,9 @@ import Language import Simplify +type R = TScal TF64 + + data SimplIters = SimplIters Int | SimplFix deriving (Show) @@ -51,19 +54,19 @@ simplifyIters iters env | Dict <- envKnown env = SimplFix -> simplifyFix -- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) +gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) gradientByCHAD simplIters env term input = let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term (out, grad) = interpretOpen False input dterm in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) -- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env))) +gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) gradientByCHAD' simplIters env term input = second (second (toTanE env input)) $ gradientByCHAD simplIters env term input -gradientByForward :: FwdADArtifact env (TScal TF64) -> SList Value env -> SList Value (TanE env) +gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env) gradientByForward art input = drevByFwd art input 1.0 extendDN :: STy t -> Rep t -> Gen (Rep (DN t)) @@ -93,6 +96,28 @@ closeIsh' h a b = closeIsh :: Double -> Double -> Bool closeIsh = closeIsh' 1e-5 +closeIshT' :: Double -> STy t -> Rep t -> Rep t -> Bool +closeIshT' _ STNil () () = True +closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h b y y' +closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x' +closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x' +closeIshT' _ STEither{} _ _ = False +closeIshT' _ (STMaybe _) Nothing Nothing = True +closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x' +closeIshT' _ STMaybe{} _ _ = False +closeIshT' h (STArr _ a) arr1 arr2 = + arrayShape arr1 == arrayShape arr2 && + and (zipWith (closeIshT' h a) (arrayToList arr1) (arrayToList arr2)) +closeIshT' _ (STScal STI32) i j = i == j +closeIshT' _ (STScal STI64) i j = i == j +closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y) +closeIshT' h (STScal STF64) x y = closeIsh' h x y +closeIshT' _ (STScal STBool) x y = x == y +closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators" + +closeIshT :: STy t -> Rep t -> Rep t -> Bool +closeIshT = closeIshT' 1e-5 + data a :$ b = a :$ b deriving (Show) ; infixl :$ -- An empty name means "no restrictions". @@ -139,7 +164,7 @@ genShape = \n tpl -> do sh <- genShapeNaive n tpl let sz = shapeSize sh factor = sz `div` 100 + 1 - return (shapeDiv sh factor) + return (shapeDiv sh tpl factor) where genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n) genShapeNaive SZ () = return ShNil @@ -156,11 +181,12 @@ genShape = \n tpl -> do Just dim -> return dim genDim :: Int -> StateT (Map String Int) Gen Int - genDim lo = Gen.integral (Range.linear lo 10) + genDim lo = Gen.integral (Range.linear lo (lo+10)) - shapeDiv :: Shape n -> Int -> Shape n - shapeDiv ShNil _ = ShNil - shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f) + shapeDiv :: Shape n -> DimNames n -> Int -> Shape n + shapeDiv ShNil _ _ = ShNil + shapeDiv (ShNil `ShCons` n) (C _ lo) f = ShNil `ShCons` (max lo (n `div` f)) + shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` (max lo (n `div` f)) genArray :: STy a -> Shape n -> Gen (Value (TArr n a)) genArray t sh = @@ -189,20 +215,42 @@ genEnv SNil () = return SNil genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl -adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree +data TypedValue t = TypedValue (STy t) (Rep t) +instance Show (TypedValue t) where + showsPrec d (TypedValue t x) = showValue d t x + +compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree +compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr + +compileTestTp :: KnownEnv env => TestName -> TemplateE env -> Ex env t -> TestTree +compileTestTp name tmpl expr = compileTestGen name expr (evalStateT (genEnv knownEnv tmpl) mempty) + +compileTestGen :: KnownEnv env => TestName -> Ex env t -> Gen (SList Value env) -> TestTree +compileTestGen name expr envGenerator = + let env = knownEnv + t = typeOf expr + in withCompiled env expr $ \fun -> + testProperty name $ property $ do + input <- forAllWith (showEnv env) envGenerator + let resI = interpretOpen False input expr + resC <- liftIO $ fun input + let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y + diff (TypedValue t resI) cmp (TypedValue t resC) + +adTest :: forall env. KnownEnv env => TestName -> Ex env R -> TestTree adTest name = adTestCon name (const True) -adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env (TScal TF64) -> TestTree +adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env R -> TestTree adTestCon name constr term = let env = knownEnv in adTestGen name term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty)) adTestTp :: forall env. KnownEnv env - => TestName -> TemplateE env -> Ex env (TScal TF64) -> TestTree + => TestName -> TemplateE env -> Ex env R -> TestTree adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) mempty) adTestGen :: forall env. KnownEnv env - => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree + => TestName -> Ex env R -> Gen (SList Value env) -> TestTree adTestGen name expr envGenerator = let env = knownEnv @env exprS = simplifyFix expr @@ -210,11 +258,11 @@ adTestGen name expr envGenerator = withCompiled env (simplifyFix expr) $ \primalSfun -> testGroupCollapse name [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun - ,adTestGenFwd env envGenerator expr exprS + ,adTestGenFwd env envGenerator exprS ,adTestGenChad env envGenerator expr exprS primalSfun] adTestGenPrimal :: SList STy env -> Gen (SList Value env) - -> Ex env (TScal TF64) -> Ex env (TScal TF64) + -> Ex env R -> Ex env R -> (SList Value env -> IO Double) -> (SList Value env -> IO Double) -> TestTree adTestGenPrimal env envGenerator expr exprS primalfun primalSfun = @@ -230,32 +278,33 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun = diff outPrimalSI (closeIsh' 1e-8) outPrimalSC adTestGenFwd :: SList STy env -> Gen (SList Value env) - -> Ex env (TScal TF64) -> Ex env (TScal TF64) + -> Ex env R -> TestTree -adTestGenFwd env envGenerator expr exprS = +adTestGenFwd env envGenerator exprS = withCompiled (dne env) (dfwdDN exprS) $ \dnfun -> testProperty "compile fwdAD" $ property $ do input <- forAllWith (showEnv env) envGenerator dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input - let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr) + let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN exprS) (outDNC1, outDNC2) <- liftIO $ dnfun dinput diff outDNI1 (closeIsh' 1e-8) outDNC1 diff outDNI2 (closeIsh' 1e-8) outDNC2 adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) - -> Ex env (TScal TF64) -> Ex env (TScal TF64) + -> Ex env R -> Ex env R -> (SList Value env -> IO Double) -> TestTree adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = + let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr + dtermChadS = simplifyFix dtermChad0 + dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS + dtermSChadS = simplifyFix dtermSChad0 + in withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> + withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS -> testProperty "chad" $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) - let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr - dtermChadS = simplifyFix dtermChad0 - let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS - dtermSChadS = simplifyFix dtermSChad0 - -- pack Text for less GC pressure (these values are retained for some reason) diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) @@ -268,44 +317,52 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = let scFwd = tanEScalars env $ gradientByForward fwdartifactC input - let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 - (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0 (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS - scChad = tanEScalars env $ toTanE env input gradChad0 - scChadS = tanEScalars env $ toTanE env input gradChadS - scSChad = tanEScalars env $ toTanE env input gradSChad0 + scChad = tanEScalars env $ toTanE env input gradChad0 + scChadS = tanEScalars env $ toTanE env input gradChadS + scSChad = tanEScalars env $ toTanE env input gradSChad0 scSChadS = tanEScalars env $ toTanE env input gradSChadS + (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> liftIO (dcompSChadS input) + let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS + -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd - diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd - diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd - diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd + annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff outCompSChadS closeIsh outPrimal + -- TODO: use closeIshT + let closeIshList x y = and (zipWith closeIsh x y) + diff scChad closeIshList scFwd + diff scChadS closeIshList scFwd + diff scSChad closeIshList scFwd + diff scSChadS closeIshList scFwd + diff scCompSChadS closeIshList scFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) -term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) +term_build1_sum :: Ex '[TArr N1 R] R term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx -term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64) +term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $ fst_ #q * #x + snd_ #q * fst_ #p -term_sparse :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) +term_sparse :: Ex '[TArr N1 R] R term_sparse = fromNamed $ lambda #inp $ body $ let_ #n (snd_ (shape #inp)) $ let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $ @@ -314,7 +371,7 @@ term_sparse = fromNamed $ lambda #inp $ body $ let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $ idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c) -term_regression_simpl1 :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) +term_regression_simpl1 :: Ex '[TArr N1 R] R term_regression_simpl1 = fromNamed $ lambda #q $ body $ idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :-> let_ #j (snd_ #idx) $ @@ -322,7 +379,7 @@ term_regression_simpl1 = fromNamed $ lambda #q $ body $ (#q ! pair nil 0) (if_ (#j .== #j) 1.0 2.0) -term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64) +term_mulmatvec :: Ex [TArr N1 R, TArr N2 R] R term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $ idx0 $ sum1i $ let_ #hei (snd_ (fst_ (shape #mat))) $ @@ -331,8 +388,31 @@ term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec idx0 (sum1i (build1 #wid $ #j :-> #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) -tests :: TestTree -tests = testGroup "AD" +tests_Compile :: TestTree +tests_Compile = testGroup "Compile" + [compileTest "accum f64" $ fromNamed $ lambda #b $ lambda #x $ body $ + with @R 0.0 $ #ac :-> + if_ #b (accum SAPHere nil #x #ac) + nil + + ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ + with @(TPair R R) nothing $ #ac :-> + let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ + let_ #_ (accum SAPHere nil #x #ac) $ + let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ + nil + + ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $ + let_ #len (snd_ (shape #x)) $ + with @(TArr N1 R) nothing $ #ac :-> + let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac) + nil) $ + let_ #_ (accum SAPHere nil (just #x) #ac) $ + nil + ] + +tests_AD :: TestTree +tests_AD = testGroup "AD" [adTest "id" $ fromNamed $ lambda #x $ body $ #x ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x @@ -361,7 +441,7 @@ tests = testGroup "AD" ,adTest "pairs" term_pairs - ,adTest "build0 const" $ fromNamed $ lambda @(TScal TF64) #x $ body $ + ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $ idx0 $ build SZ nil $ #idx :-> const_ 0.0 ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $ @@ -375,14 +455,14 @@ tests = testGroup "AD" build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ - fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ + fromNamed $ lambda @(TArr N2 R) #x $ body $ idx0 $ sum1i $ maximum1i #x ,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ - fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ + fromNamed $ lambda @(TArr N2 R) #x $ body $ idx0 $ sum1i $ minimum1i #x - ,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ + ,adTest "unused" $ fromNamed $ lambda @(TArr N1 R) #x $ body $ let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $ 42 @@ -391,6 +471,15 @@ tests = testGroup "AD" -- Regression test for a simplifier bug (89b78d4) ,adTestTp "regression-simpl1" (C "" 1) term_regression_simpl1 + -- Regression test for refcounts when indexing in nested arrays + ,adTestTp "regression-idx1" (C "" 1 :$ C "" 1) $ + fromNamed $ lambda @(TArr N2 R) #L $ body $ + if_ (const_ @TI64 1 Language..> 0) + (idx0 $ sum1i (build1 1 $ #_ :-> + idx0 (sum1i (build1 1 $ #_ :-> + #L ! pair (pair nil 0) 0 * #L ! pair (pair nil 0) 0)))) + 42 + ,adTestGen "neural" Example.neural genNeural ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural @@ -449,4 +538,6 @@ tests = testGroup "AD" return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) main :: IO () -main = defaultMain tests +main = defaultMain $ testGroup "All" + [tests_Compile + ,tests_AD] |