diff options
Diffstat (limited to 'test/Main.hs')
| -rw-r--r-- | test/Main.hs | 515 |
1 files changed, 348 insertions, 167 deletions
diff --git a/test/Main.hs b/test/Main.hs index 20b4ef0..05597cc 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,4 +1,6 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE OverloadedStrings #-} @@ -10,38 +12,50 @@ {-# LANGUAGE UndecidableInstances #-} module Main where +import Control.Monad (when) import Control.Monad.Trans.Class (lift) -import Control.Monad.IO.Class (liftIO) import Control.Monad.Trans.State import Data.Bifunctor import Data.Int (Int64) import Data.Map.Strict (Map) -import qualified Data.Map.Strict as Map -import qualified Data.Text as T +import Data.Map.Strict qualified as Map +import Data.Text qualified as T import Hedgehog -import qualified Hedgehog.Gen as Gen -import qualified Hedgehog.Range as Range +import Hedgehog.Gen qualified as Gen +import Hedgehog.Range qualified as Range import Test.Framework -import Array -import AST -import AST.Pretty -import AST.UnMonoid -import CHAD.Top -import CHAD.Types -import CHAD.Types.ToTan -import Compile -import qualified Example -import qualified Example.GMM as Example -import ForwardAD -import ForwardAD.DualNumbers -import Interpreter -import Interpreter.Rep -import Language -import Simplify +import CHAD.Array +import CHAD.AST hiding ((.>)) +import CHAD.AST.Count (pruneExpr) +import CHAD.AST.Pretty +import CHAD.AST.UnMonoid +import CHAD.Compile +import CHAD.Data +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Drev.Types.ToTan +import CHAD.Example qualified as Example +import CHAD.Example.GMM qualified as Example +import CHAD.Example.Types +import CHAD.ForwardAD +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter +import CHAD.Interpreter.Rep +import CHAD.Language +import CHAD.Simplify -type R = TScal TF64 +data TypedValue t = TypedValue (STy t) (Rep t) +instance Show (TypedValue t) where + showsPrec d (TypedValue t x) = showValue d t x + +data TypedEnv env = TypedEnv (SList STy env) (SList Value env) +instance Show (TypedEnv env) where + show (TypedEnv env xs) = showEnv env xs + +unTypedEnv :: TypedEnv env -> SList Value env +unTypedEnv (TypedEnv _ xs) = xs data SimplIters = SimplIters Int | SimplFix @@ -53,33 +67,37 @@ simplifyIters iters env | Dict <- envKnown env = SimplIters n -> simplifyN n SimplFix -> simplifyFix --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) -gradientByCHAD simplIters env term input = - let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term - (out, grad) = interpretOpen False input dterm - in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) +-- gradientByCHAD simplIters env term input = +-- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term +-- (out, grad) = interpretOpen False env input dterm +-- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) -gradientByCHAD' simplIters env term input = - second (second (toTanE env input)) $ - gradientByCHAD simplIters env term input +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) +-- gradientByCHAD' simplIters env term input = +-- second (second (toTanE env input)) $ +-- gradientByCHAD simplIters env term input gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env) gradientByForward art input = drevByFwd art input 1.0 +-- | Generate input tangents for this primal extendDN :: STy t -> Rep t -> Gen (Rep (DN t)) extendDN STNil () = pure () extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y extendDN (STEither a _) (Left x) = Left <$> extendDN a x extendDN (STEither _ b) (Right y) = Right <$> extendDN b y +extendDN (STLEither _ _) Nothing = pure Nothing +extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x +extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y extendDN (STMaybe _) Nothing = pure Nothing extendDN (STMaybe t) (Just x) = Just <$> extendDN t x extendDN (STArr _ t) arr = traverse (extendDN t) arr extendDN (STScal sty) x = case sty of - STF32 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) - STF64 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) + STF32 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d) + STF64 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d) STI32 -> pure x STI64 -> pure x STBool -> pure x @@ -102,6 +120,10 @@ closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x' closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x' closeIshT' _ STEither{} _ _ = False +closeIshT' _ (STLEither _ _) Nothing Nothing = True +closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x' +closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y' +closeIshT' _ STLEither{} _ _ = False closeIshT' _ (STMaybe _) Nothing Nothing = True closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x' closeIshT' _ STMaybe{} _ _ = False @@ -118,20 +140,29 @@ closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators" closeIshT :: STy t -> Rep t -> Rep t -> Bool closeIshT = closeIshT' 1e-5 +closeIshE :: SList STy t -> SList Value t -> SList Value t -> Bool +closeIshE SNil SNil SNil = True +closeIshE (t `SCons` env) (Value x `SCons` xs) (Value y `SCons` ys) = + closeIshT t x y && closeIshE env xs ys + data a :$ b = a :$ b deriving (Show) ; infixl :$ --- An empty name means "no restrictions". -data TplConstr = C String -- ^ name; @""@ means anonymous - Int -- ^ minimum value to generate +-- | The type index is just a marker that helps typed holes show what (type of) +-- argument this template constraint belongs to. +data TplConstr a = C String -- ^ name; @""@ means anonymous + Int -- ^ minimum value to generate + | NC -- ^ no constraints type family DimNames n where DimNames Z = () - DimNames (S Z) = TplConstr - DimNames (S n) = DimNames n :$ TplConstr + DimNames (S Z) = TplConstr (S Z) + DimNames (S n) = DimNames n :$ TplConstr (S n) type family Tpl t where Tpl (TArr n t) = DimNames n Tpl (TPair a b) = (Tpl a, Tpl b) + Tpl (TScal TI32) = TplConstr TI32 + Tpl (TScal TI64) = TplConstr TI64 -- If you add equations here, don't forget to update genValue! It currently -- just emptyTpl's things out. Tpl _ = () @@ -145,13 +176,17 @@ type family TemplateE env where emptyDimNames :: SNat n -> DimNames n emptyDimNames SZ = () -emptyDimNames (SS SZ) = C "" 0 -emptyDimNames (SS n@SS{}) = emptyDimNames n :$ C "" 0 +emptyDimNames (SS SZ) = NC +emptyDimNames (SS n@SS{}) = emptyDimNames n :$ NC emptyTpl :: STy t -> Tpl t emptyTpl (STArr n _) = emptyDimNames n emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b) -emptyTpl (STScal _) = () +emptyTpl (STScal STI32) = NC +emptyTpl (STScal STI64) = NC +emptyTpl (STScal STF32) = () +emptyTpl (STScal STF64) = () +emptyTpl (STScal STBool) = () emptyTpl _ = error "too lazy" emptyTemplateE :: SList STy env -> TemplateE env @@ -171,7 +206,8 @@ genShape = \n tpl -> do genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name - genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int + genNamedDim :: TplConstr a -> StateT (Map String Int) Gen Int + genNamedDim NC = genDim 0 genNamedDim (C "" lo) = genDim lo genNamedDim (C name lo) = gets (Map.lookup name) >>= \case Nothing -> do @@ -185,40 +221,54 @@ genShape = \n tpl -> do shapeDiv :: Shape n -> DimNames n -> Int -> Shape n shapeDiv ShNil _ _ = ShNil - shapeDiv (ShNil `ShCons` n) (C _ lo) f = ShNil `ShCons` (max lo (n `div` f)) - shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` (max lo (n `div` f)) + shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` max lo (n `div` f) + shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` max lo (n `div` f) + shapeDiv (ShNil `ShCons` n) NC f = ShNil `ShCons` (n `div` f) + shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ NC) f = shapeDiv sh tpl f `ShCons` (n `div` f) genArray :: STy a -> Shape n -> Gen (Value (TArr n a)) genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> evalStateT (genValue t (emptyTpl t)) mempty) -genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t) +genValue :: forall t. STy t -> Tpl t -> StateT (Map String Int) Gen (Value t) genValue topty tpl = case topty of STNil -> return (Value ()) STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl) STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a) ,liftV Right <$> genValue b (emptyTpl b)] + STLEither a b -> Gen.frequency [(1, pure (Value Nothing)) + ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a)) + ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))] STMaybe t -> Gen.choice [return (Value Nothing) ,liftV Just <$> genValue t (emptyTpl t)] STArr n t -> genShape n tpl >>= lift . genArray t STScal sty -> case sty of - STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) - STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) - STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10) - STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10) + STF32 -> Value <$> Gen.realFloat (Range.constant (-10) 10) + STF64 -> Value <$> Gen.realFloat (Range.constant (-10) 10) + STI32 -> genInt + STI64 -> genInt STBool -> Gen.choice [return (Value False), return (Value True)] STAccum{} -> error "Cannot generate inputs for accumulators" + where + genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t) + genInt = do + let gen lo = Gen.integral (Range.linearFrom 0 lo (max 10 (lo + 10))) + val <- case tpl of + NC -> gen (-10) + C name lo -> gets (Map.lookup name) >>= \case + Nothing -> do + val <- fromIntegral @Int @(Rep t) <$> gen lo + modify (Map.insert name (fromIntegral @(Rep t) @Int val)) + return val + Just val -> return (fromIntegral @Int @(Rep t) val) + return (Value val) genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env) genEnv SNil () = return SNil genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl -data TypedValue t = TypedValue (STy t) (Rep t) -instance Show (TypedValue t) where - showsPrec d (TypedValue t x) = showValue d t x - compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr @@ -232,8 +282,8 @@ compileTestGen name expr envGenerator = in withCompiled env expr $ \fun -> testProperty name $ property $ do input <- forAllWith (showEnv env) envGenerator - let resI = interpretOpen False input expr - resC <- liftIO $ fun input + let resI = interpretOpen False env input expr + resC <- evalIO $ fun input let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y diff (TypedValue t resI) cmp (TypedValue t resC) @@ -256,10 +306,12 @@ adTestGen name expr envGenerator = exprS = simplifyFix expr in withCompiled env expr $ \primalfun -> withCompiled env (simplifyFix expr) $ \primalSfun -> - testGroupCollapse name + groupSetCollapse $ testGroup name [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun ,adTestGenFwd env envGenerator exprS - ,adTestGenChad env envGenerator expr exprS primalSfun] + ,testGroup "chad" + [adTestGenChad "default" defaultConfig env envGenerator expr exprS primalSfun + ,adTestGenChad "accum" (chcSetAccum defaultConfig) env envGenerator expr exprS primalSfun]] adTestGenPrimal :: SList STy env -> Gen (SList Value env) -> Ex env R -> Ex env R @@ -269,12 +321,12 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun = testProperty "compile primal" $ property $ do input <- forAllWith (showEnv env) envGenerator - let outPrimalI = interpretOpen False input expr - outPrimalC <- liftIO $ primalfun input + let outPrimalI = interpretOpen False env input expr + outPrimalC <- evalIO $ primalfun input diff outPrimalI (closeIsh' 1e-8) outPrimalC - let outPrimalSI = interpretOpen False input exprS - outPrimalSC <- liftIO $ primalSfun input + let outPrimalSI = interpretOpen False env input exprS + outPrimalSC <- evalIO $ primalSfun input diff outPrimalSI (closeIsh' 1e-8) outPrimalSC adTestGenFwd :: SList STy env -> Gen (SList Value env) @@ -285,84 +337,167 @@ adTestGenFwd env envGenerator exprS = testProperty "compile fwdAD" $ property $ do input <- forAllWith (showEnv env) envGenerator dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input - let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN exprS) - (outDNC1, outDNC2) <- liftIO $ dnfun dinput + let (outDNI1, outDNI2) = interpretOpen False (dne env) dinput (dfwdDN exprS) + (outDNC1, outDNC2) <- evalIO $ dnfun dinput diff outDNI1 (closeIsh' 1e-8) outDNC1 diff outDNI2 (closeIsh' 1e-8) outDNC2 -adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) +adTestGenChad :: forall env. String -> CHADConfig -> SList STy env -> Gen (SList Value env) -> Ex env R -> Ex env R -> (SList Value env -> IO Double) -> TestTree -adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = - let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr +adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- envKnown env = + let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env expr dtermChadS = simplifyFix dtermChad0 - dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS + dtermChadSUS = simplifyFix $ unMonoid dtermChadS + dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS dtermSChadS = simplifyFix dtermSChad0 + dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS + dtermSChadSUSP = simplifyFix $ pruneExpr env dtermSChadSUS in - withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> - withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS -> - testProperty "chad" $ property $ do + withResource' (do (fun, output) <- makeFwdADArtifactCompile env exprS + when (not (null output)) $ + outputWarningText $ "Forward AD compile GCC output: <<<\n" ++ output ++ ">>>" + return fun) $ \fwdartifactC -> + withCompiled env dtermSChadSUSP $ \dcompSChadSUSP -> + testProperty testname $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) - -- pack Text for less GC pressure (these values are retained for some reason) - diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) - diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) + -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason) + let dtermChad20 = simplifyN 20 dtermChad0 + diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChad20)) + diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermChad20))) + let dtermSChad20 = simplifyN 20 dtermSChad0 + diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env dtermSChad20)) + diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermSChad20))) input <- forAllWith (showEnv env) envGenerator - outPrimal <- liftIO $ primalSfun input + outPrimal <- evalIO $ primalSfun input let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env) unpackGrad = unTup vUnpair (d2e env) . Value - let scFwd = tanEScalars env $ gradientByForward fwdartifactC input + let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input - let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 - (outChadS , gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS - (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0 - (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS - scChad = tanEScalars env $ toTanE env input gradChad0 - scChadS = tanEScalars env $ toTanE env input gradChadS - scSChad = tanEScalars env $ toTanE env input gradSChad0 - scSChadS = tanEScalars env $ toTanE env input gradSChadS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS + (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS + (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 + (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS + (outSChadSUS , gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS + (outSChadSUSP, gradSChadSUSP) = second unpackGrad $ interpretOpen False env input dtermSChadSUSP + tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 + tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS + tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS + tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 + tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS + tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS + tansSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradSChadSUSP - (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> liftIO (dcompSChadS input) - let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS + (outCompSChadSUSP, gradCompSChadSUSP) <- second unpackGrad <$> evalIO (dcompSChadSUSP input) + let tansCompSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUSP - -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) - -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) + -- annotate (showEnv (d2e env) gradChad0) + -- annotate (showEnv (d2e env) gradChadS) -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) - annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff outCompSChadS closeIsh outPrimal - -- TODO: use closeIshT - let closeIshList x y = and (zipWith closeIsh x y) - diff scChad closeIshList scFwd - diff scChadS closeIshList scFwd - diff scSChad closeIshList scFwd - diff scSChadS closeIshList scFwd - diff scCompSChadS closeIshList scFwd + annotate (ppExpr env dtermSChadSUSP) + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outChadSUS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff outSChadSUS closeIsh outPrimal + diff outSChadSUSP closeIsh outPrimal + diff outCompSChadSUSP closeIsh outPrimal + let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2) + diff tansChad closeIshE' tansFwd + diff tansChadS closeIshE' tansFwd + diff tansChadSUS closeIshE' tansFwd + diff tansSChad closeIshE' tansFwd + diff tansSChadS closeIshE' tansFwd + diff tansSChadSUS closeIshE' tansFwd + diff tansSChadSUSP closeIshE' tansFwd + diff tansCompSChadSUSP closeIshE' tansFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree -withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) +withCompiled env expr = withResource' $ do + (fun, output) <- compile env expr + when (not (null output)) $ + outputWarningText $ "Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" + return fun + +gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]) +gen_gmm = do + -- The input ranges here are completely arbitrary. + let tR = STScal STF64 + kN <- Gen.integral (Range.linear 1 8) + kD <- Gen.integral (Range.linear 1 8) + kK <- Gen.integral (Range.linear 1 8) + let i2i64 = fromIntegral @Int @Int64 + valpha <- genArray tR (ShNil `ShCons` kK) + vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD) + vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) + vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) + vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) + vgamma <- Gen.realFloat (Range.constant (-10) 10) + vm <- Gen.integral (Range.linear 0 5) + let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) + k2 = 0.5 * vgamma * vgamma + k3 = 0.42 -- don't feel like multigammaing today + return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` + Value vm `SCons` vX `SCons` + vL `SCons` vQ `SCons` vM `SCons` valpha `SCons` + Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` + SNil) + +gen_neural :: Gen (SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]) +gen_neural = do + let tR = STScal STF64 + let genLayer nin nout = + liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) + <*> genArray tR (ShNil `ShCons` nout) + nin <- Gen.integral (Range.linear 1 10) + n1 <- Gen.integral (Range.linear 1 10) + n2 <- Gen.integral (Range.linear 1 10) + input <- genArray tR (ShNil `ShCons` nin) + lay1 <- genLayer nin n1 + lay2 <- genLayer n1 n2 + lay3 <- genArray tR (ShNil `ShCons` n2) + return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) + +term_build0 :: Ex '[TArr N0 R] R +term_build0 = fromNamed $ lambda @(TArr N0 _) #x $ body $ + idx0 $ + build SZ (shape #x) $ #idx :-> #x ! #idx -term_build1_sum :: Ex '[TArr N1 R] R +term_build1_sum :: Ex '[TVec R] R term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx +term_build1_idx :: Ex '[TVec R] R +term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ + build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i)) + +term_idx_coprod :: Ex '[TVec (TEither R R)] R +term_idx_coprod = fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + case_ (#x ! pair nil #i) + (#a :-> #a * 2) + (#b :-> #b * 3) + term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $ fst_ #q * #x + snd_ #q * fst_ #p -term_sparse :: Ex '[TArr N1 R] R +term_sparse :: Ex '[TVec R] R term_sparse = fromNamed $ lambda #inp $ body $ let_ #n (snd_ (shape #inp)) $ let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $ @@ -371,7 +506,7 @@ term_sparse = fromNamed $ lambda #inp $ body $ let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $ idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c) -term_regression_simpl1 :: Ex '[TArr N1 R] R +term_regression_simpl1 :: Ex '[TVec R] R term_regression_simpl1 = fromNamed $ lambda #q $ body $ idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :-> let_ #j (snd_ #idx) $ @@ -379,8 +514,8 @@ term_regression_simpl1 = fromNamed $ lambda #q $ body $ (#q ! pair nil 0) (if_ (#j .== #j) 1.0 2.0) -term_mulmatvec :: Ex [TArr N1 R, TArr N2 R] R -term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $ +term_mulmatvec :: Ex [TVec R, TMat R] R +term_mulmatvec = fromNamed $ lambda #mat $ lambda #vec $ body $ idx0 $ sum1i $ let_ #hei (snd_ (fst_ (shape #mat))) $ let_ #wid (snd_ (shape #mat)) $ @@ -388,6 +523,27 @@ term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec idx0 (sum1i (build1 #wid $ #j :-> #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) +term_arr_rebind :: Ex '[I64, TVec R] R +term_arr_rebind = fromNamed $ lambda #a $ lambda #k $ body $ + let_ #n (if_ (#k .< length_ #a) #k (length_ #a)) $ + let_ #b (build1 #n (#i :-> #a ! pair nil #i)) $ + let_ #p (if_ (#n `mod_` 2 .== 1) + (pair #a #b) + (pair (map_ (#x :-> #x + 1) #a) #b)) $ + if_ (#n `mod_` 3 .== 1) + (idx0 (sum1i (snd_ #p))) + (let_ #b' (snd_ #p) $ + idx0 (sum1i #b') * idx0 (sum1i (map_ (#x :-> 2 * #x) #b'))) + +-- This simplifies away to a pointless test, but is helpful for debugging what +-- term_arr_rebind is supposed to test in a REPL +term_arr_rebind_simple :: Ex '[TVec R] R +term_arr_rebind_simple = fromNamed $ lambda #a $ body $ + let_ #b (build1 (length_ #a) (#i :-> 5 * (#a ! pair nil #i))) $ + let_ #c #b $ + let_ #d #c $ + idx0 (sum1i #d) + tests_Compile :: TestTree tests_Compile = testGroup "Compile" [compileTest "accum f64" $ fromNamed $ lambda #b $ lambda #x $ body $ @@ -396,19 +552,33 @@ tests_Compile = testGroup "Compile" nil ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ - with @(TPair R R) nothing $ #ac :-> + with @(TPair R R) (pair 0.0 0.0) $ #ac :-> let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ nil - ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $ + ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $ + with @(TMaybe (TPair R R)) (just (pair 0 0)) $ #ac :-> + let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) nil 3.0 #ac) nil) $ + let_ #_ (accum SAPHere nil #x #ac) $ + let_ #_ (accum (SAPJust (SAPSnd SAPHere)) nil 4.0 #ac) $ + nil + + ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $ let_ #len (snd_ (shape #x)) $ - with @(TArr N1 R) nothing $ #ac :-> - let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac) + with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :-> + let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair nil 2) nil) 6.0 #ac) nil) $ - let_ #_ (accum SAPHere nil (just #x) #ac) $ + let_ #_ (accum SAPHere nil #x #ac) $ nil + + ,compileTest "foldd1" $ fromNamed $ lambda @(TVec R) #a $ body $ + fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a + + ,compileTest "fold-manual" $ fromNamed $ lambda @(TVec R) #a $ lambda #d $ body $ + let_ #pr (fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a) $ + fold1iD2 (#tape :-> #ctg :-> pair (snd_ #tape * #ctg) (fst_ #tape * #ctg)) (snd_ #pr) #d ] tests_AD :: TestTree @@ -444,25 +614,44 @@ tests_AD = testGroup "AD" ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $ idx0 $ build SZ nil $ #idx :-> const_ 0.0 - ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $ - idx0 $ - build SZ (shape #x) $ #idx :-> #x ! #idx + ,adTest "build0" term_build0 ,adTest "build1-sum" term_build1_sum - ,adTest "build2-sum" $ fromNamed $ lambda @(TArr N2 _) #x $ body $ + ,adTest "build2-sum" $ fromNamed $ lambda @(TMat _) #x $ body $ idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx + ,adTest "build1-idx" term_build1_idx + + ,adTest "idx-pair" $ fromNamed $ lambda @(TVec (TPair R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#x ! pair nil #i) $ + 3 * fst_ #p + 2 * snd_ #p + + ,adTest "idx-coprod" $ term_idx_coprod + + ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $ + let_ #n (snd_ (shape #arr)) $ + let_ #b (build1 #n (#i :-> let_ #x (#arr ! pair nil #i) $ + if_ (#x .>= 1) (pair (inl (pair #x (7 * #x))) (2 * #x)) + (pair (inr (3 * #x)) (exp #x)))) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#b ! pair nil #i) $ + case_ (fst_ #p) + (#a :-> fst_ #a * 2 + snd_ #a * snd_ #p) + (#b :-> #b * 4) + ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ - fromNamed $ lambda @(TArr N2 R) #x $ body $ + fromNamed $ lambda @(TMat R) #x $ body $ idx0 $ sum1i $ maximum1i #x ,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ - fromNamed $ lambda @(TArr N2 R) #x $ body $ + fromNamed $ lambda @(TMat R) #x $ body $ idx0 $ sum1i $ minimum1i #x - ,adTest "unused" $ fromNamed $ lambda @(TArr N1 R) #x $ body $ + ,adTest "unused" $ fromNamed $ lambda @(TVec R) #x $ body $ let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $ 42 @@ -473,69 +662,61 @@ tests_AD = testGroup "AD" -- Regression test for refcounts when indexing in nested arrays ,adTestTp "regression-idx1" (C "" 1 :$ C "" 1) $ - fromNamed $ lambda @(TArr N2 R) #L $ body $ - if_ (const_ @TI64 1 Language..> 0) + fromNamed $ lambda @(TMat R) #L $ body $ + if_ (const_ @TI64 1 .> 0) (idx0 $ sum1i (build1 1 $ #_ :-> idx0 (sum1i (build1 1 $ #_ :-> #L ! pair (pair nil 0) 0 * #L ! pair (pair nil 0) 0)))) 42 - ,adTestGen "neural" Example.neural genNeural + ,adTest "arr-rebind-simple" term_arr_rebind_simple + ,adTestTp "arr-rebind" (NC :& C "" 0) term_arr_rebind - ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural + ,adTestGen "neural" Example.neural gen_neural ,adTestTp "logsumexp" (C "" 1) $ - fromNamed $ lambda @(TArr N1 _) #vec $ body $ + fromNamed $ lambda @(TVec _) #vec $ body $ let_ #m (maximum1i #vec) $ log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m - ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) term_mulmatvec + ,adTestTp "mulmatvec" ((NC :$ C "n" 0) :& C "n" 0) term_mulmatvec - ,adTestGen "gmm-wrong" (Example.gmmObjective True) genGMM + ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm - ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM + ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm - ,adTestGen "gmm" (Example.gmmObjective False) genGMM + ,adTestTp "uniform-free" (C "" 0 :& ()) Example.exUniformFree - ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM - ] - where - genGMM = do - -- The input ranges here are completely arbitrary. - let tR = STScal STF64 - kN <- Gen.integral (Range.linear 1 8) - kD <- Gen.integral (Range.linear 1 8) - kK <- Gen.integral (Range.linear 1 8) - let i2i64 = fromIntegral @Int @Int64 - valpha <- genArray tR (ShNil `ShCons` kK) - vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD) - vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) - vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) - vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) - vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10) - vm <- Gen.integral (Range.linear 0 5) - let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) - k2 = 0.5 * vgamma * vgamma - k3 = 0.42 -- don't feel like multigammaing today - return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` - Value vm `SCons` vX `SCons` - vL `SCons` vQ `SCons` vM `SCons` valpha `SCons` - Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` - SNil) + ,adTest "reshape1" $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #n (snd_ #sh * snd_ (fst_ #sh)) $ + idx0 $ sum1i $ reshape (SS SZ) (pair nil #n) #a + + ,adTestTp "reshape2" (C "" 1 :$ NC) $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #innern (snd_ #sh) $ + let_ #n (#innern * snd_ (fst_ #sh)) $ + let_ #flata (reshape (SS SZ) (pair nil #n) #a) $ + -- ensure the input array to EReshape is shared + idx0 $ sum1i $ + build1 #n (#i :-> #flata ! pair nil #i + #a ! pair (pair nil 0) (#i `mod_` #innern)) + + ,adTest "fold-sum" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x + #y) 0 #a - genNeural = do - let tR = STScal STF64 - let genLayer nin nout = - liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) - <*> genArray tR (ShNil `ShCons` nout) - nin <- Gen.integral (Range.linear 1 10) - n1 <- Gen.integral (Range.linear 1 10) - n2 <- Gen.integral (Range.linear 1 10) - input <- genArray tR (ShNil `ShCons` nin) - lay1 <- genLayer nin n1 - lay2 <- genLayer n1 n2 - lay3 <- genArray tR (ShNil `ShCons` n2) - return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) + ,adTest "fold-prod" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x * #y) 1 #a + + ,adTest "fold-freevar" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + let_ #v 2 $ + idx0 $ fold1i (#x :-> #y :-> #x * #y + #v) 1 #a + + ,adTestTp "fold-freearr" (C "" 1) $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x * #y + #a ! pair nil 0) 1 #a + + ,adTest "map" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ sum1i $ map_ (#x :-> 2 * #x) #a + ] main :: IO () main = defaultMain $ testGroup "All" |
