diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 88 |
1 files changed, 56 insertions, 32 deletions
diff --git a/test/Main.hs b/test/Main.hs index afbd79b..f5e4a3c 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -42,6 +42,18 @@ import Language import Simplify +data TypedValue t = TypedValue (STy t) (Rep t) +instance Show (TypedValue t) where + showsPrec d (TypedValue t x) = showValue d t x + +data TypedEnv env = TypedEnv (SList STy env) (SList Value env) +instance Show (TypedEnv env) where + show (TypedEnv env xs) = showEnv env xs + +unTypedEnv :: TypedEnv env -> SList Value env +unTypedEnv (TypedEnv _ xs) = xs + + data SimplIters = SimplIters Int | SimplFix deriving (Show) @@ -67,6 +79,7 @@ gradientByCHAD' simplIters env term input = gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env) gradientByForward art input = drevByFwd art input 1.0 +-- | Generate input tangents for this primal extendDN :: STy t -> Rep t -> Gen (Rep (DN t)) extendDN STNil () = pure () extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y @@ -82,6 +95,9 @@ extendDN (STScal sty) x = case sty of STI64 -> pure x STBool -> pure x extendDN (STAccum _) _ = error "Accumulators not supported in input program" +extendDN (STLEither _ _) Nothing = pure Nothing +extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x +extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env)) extendDNE SNil SNil = pure SNil @@ -112,10 +128,19 @@ closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y) closeIshT' h (STScal STF64) x y = closeIsh' h x y closeIshT' _ (STScal STBool) x y = x == y closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators" +closeIshT' _ (STLEither _ _) Nothing Nothing = True +closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x' +closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y' +closeIshT' _ STLEither{} _ _ = False closeIshT :: STy t -> Rep t -> Rep t -> Bool closeIshT = closeIshT' 1e-5 +closeIshE :: SList STy t -> SList Value t -> SList Value t -> Bool +closeIshE SNil SNil SNil = True +closeIshE (t `SCons` env) (Value x `SCons` xs) (Value y `SCons` ys) = + closeIshT t x y && closeIshE env xs ys + data a :$ b = a :$ b deriving (Show) ; infixl :$ -- | The type index is just a marker that helps typed holes show what (type of) @@ -218,6 +243,9 @@ genValue topty tpl = case topty of STI64 -> genInt STBool -> Gen.choice [return (Value False), return (Value True)] STAccum{} -> error "Cannot generate inputs for accumulators" + STLEither a b -> Gen.frequency [(1, pure (Value Nothing)) + ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a)) + ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))] where genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t) genInt = do @@ -237,10 +265,6 @@ genEnv SNil () = return SNil genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl -data TypedValue t = TypedValue (STy t) (Rep t) -instance Show (TypedValue t) where - showsPrec d (TypedValue t x) = showValue d t x - compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr @@ -337,22 +361,22 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env) unpackGrad = unTup vUnpair (d2e env) . Value - let scFwd = tanEScalars env $ gradientByForward fwdartifactC input + let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS - scChad = tanEScalars env $ toTanE env input gradChad0 - scChadS = tanEScalars env $ toTanE env input gradChadS - scSChad = tanEScalars env $ toTanE env input gradSChad0 - scSChadS = tanEScalars env $ toTanE env input gradSChadS + tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 + tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS + tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 + tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input) - let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS + let tansCompSChadS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadS - -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) - -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) + -- annotate (showEnv (d2e env) gradChad0) + -- annotate (showEnv (d2e env) gradChadS) -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) @@ -362,13 +386,12 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = diff outSChad0 closeIsh outPrimal diff outSChadS closeIsh outPrimal diff outCompSChadS closeIsh outPrimal - -- TODO: use closeIshT - let closeIshList x y = and (zipWith closeIsh x y) - diff scChad closeIshList scFwd - diff scChadS closeIshList scFwd - diff scSChad closeIshList scFwd - diff scSChadS closeIshList scFwd - diff scCompSChadS closeIshList scFwd + let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2) + diff tansChad closeIshE' tansFwd + diff tansChadS closeIshE' tansFwd + diff tansSChad closeIshE' tansFwd + diff tansSChadS closeIshE' tansFwd + diff tansCompSChadS closeIshE' tansFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) @@ -478,18 +501,25 @@ tests_Compile = testGroup "Compile" nil ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ - with @(TPair R R) nothing $ #ac :-> - let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ + with @(TPair R R) (pair 0.0 0.0) $ #ac :-> + let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $ + let_ #_ (accum SAPHere nil #x #ac) $ + let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $ + nil + + ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $ + with @(TMaybe (TPair R R)) nothing $ #ac :-> + let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ + let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $ nil - ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $ + ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $ let_ #len (snd_ (shape #x)) $ - with @(TVec R) nothing $ #ac :-> - let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac) + with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :-> + let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac) nil) $ - let_ #_ (accum SAPHere nil (just #x) #ac) $ + let_ #_ (accum SAPHere nil #x #ac) $ nil ] @@ -567,8 +597,6 @@ tests_AD = testGroup "AD" ,adTestGen "neural" Example.neural gen_neural - ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) gen_neural - ,adTestTp "logsumexp" (C "" 1) $ fromNamed $ lambda @(TVec _) #vec $ body $ let_ #m (maximum1i #vec) $ @@ -578,11 +606,7 @@ tests_AD = testGroup "AD" ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm - ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) gen_gmm - ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm - - ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) gen_gmm ] main :: IO () |