diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-27 00:01:15 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-27 00:01:15 +0100 |
commit | 75141f1c1f97fef563df2be6e512e568f922cb45 (patch) | |
tree | e621c50b2b43fb4050ef074365ef29016869dd35 /test | |
parent | 6e85d5b2aee0cf2c089538e74261f1d88d6b1b71 (diff) |
test: type R = TScal TF64
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 41 |
1 files changed, 22 insertions, 19 deletions
diff --git a/test/Main.hs b/test/Main.hs index 014ad43..933629c 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -41,6 +41,9 @@ import Language import Simplify +type R = TScal TF64 + + data SimplIters = SimplIters Int | SimplFix deriving (Show) @@ -51,19 +54,19 @@ simplifyIters iters env | Dict <- envKnown env = SimplFix -> simplifyFix -- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) +gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) gradientByCHAD simplIters env term input = let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term (out, grad) = interpretOpen False input dterm in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) -- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env))) +gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) gradientByCHAD' simplIters env term input = second (second (toTanE env input)) $ gradientByCHAD simplIters env term input -gradientByForward :: FwdADArtifact env (TScal TF64) -> SList Value env -> SList Value (TanE env) +gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env) gradientByForward art input = drevByFwd art input 1.0 extendDN :: STy t -> Rep t -> Gen (Rep (DN t)) @@ -227,20 +230,20 @@ compileTest name expr = let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y diff (TypedValue t resI) cmp (TypedValue t resC) -adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree +adTest :: forall env. KnownEnv env => TestName -> Ex env R -> TestTree adTest name = adTestCon name (const True) -adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env (TScal TF64) -> TestTree +adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env R -> TestTree adTestCon name constr term = let env = knownEnv in adTestGen name term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty)) adTestTp :: forall env. KnownEnv env - => TestName -> TemplateE env -> Ex env (TScal TF64) -> TestTree + => TestName -> TemplateE env -> Ex env R -> TestTree adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) mempty) adTestGen :: forall env. KnownEnv env - => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree + => TestName -> Ex env R -> Gen (SList Value env) -> TestTree adTestGen name expr envGenerator = let env = knownEnv @env exprS = simplifyFix expr @@ -252,7 +255,7 @@ adTestGen name expr envGenerator = ,adTestGenChad env envGenerator expr exprS primalSfun] adTestGenPrimal :: SList STy env -> Gen (SList Value env) - -> Ex env (TScal TF64) -> Ex env (TScal TF64) + -> Ex env R -> Ex env R -> (SList Value env -> IO Double) -> (SList Value env -> IO Double) -> TestTree adTestGenPrimal env envGenerator expr exprS primalfun primalSfun = @@ -268,7 +271,7 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun = diff outPrimalSI (closeIsh' 1e-8) outPrimalSC adTestGenFwd :: SList STy env -> Gen (SList Value env) - -> Ex env (TScal TF64) + -> Ex env R -> TestTree adTestGenFwd env envGenerator exprS = withCompiled (dne env) (dfwdDN exprS) $ \dnfun -> @@ -281,7 +284,7 @@ adTestGenFwd env envGenerator exprS = diff outDNI2 (closeIsh' 1e-8) outDNC2 adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) - -> Ex env (TScal TF64) -> Ex env (TScal TF64) + -> Ex env R -> Ex env R -> (SList Value env -> IO Double) -> TestTree adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = @@ -341,18 +344,18 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) -term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) +term_build1_sum :: Ex '[TArr N1 R] R term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx -term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64) +term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $ fst_ #q * #x + snd_ #q * fst_ #p -term_sparse :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) +term_sparse :: Ex '[TArr N1 R] R term_sparse = fromNamed $ lambda #inp $ body $ let_ #n (snd_ (shape #inp)) $ let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $ @@ -361,7 +364,7 @@ term_sparse = fromNamed $ lambda #inp $ body $ let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $ idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c) -term_regression_simpl1 :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) +term_regression_simpl1 :: Ex '[TArr N1 R] R term_regression_simpl1 = fromNamed $ lambda #q $ body $ idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :-> let_ #j (snd_ #idx) $ @@ -369,7 +372,7 @@ term_regression_simpl1 = fromNamed $ lambda #q $ body $ (#q ! pair nil 0) (if_ (#j .== #j) 1.0 2.0) -term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64) +term_mulmatvec :: Ex [TArr N1 R, TArr N2 R] R term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $ idx0 $ sum1i $ let_ #hei (snd_ (fst_ (shape #mat))) $ @@ -414,7 +417,7 @@ tests_AD = testGroup "AD" ,adTest "pairs" term_pairs - ,adTest "build0 const" $ fromNamed $ lambda @(TScal TF64) #x $ body $ + ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $ idx0 $ build SZ nil $ #idx :-> const_ 0.0 ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $ @@ -428,14 +431,14 @@ tests_AD = testGroup "AD" build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ - fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ + fromNamed $ lambda @(TArr N2 R) #x $ body $ idx0 $ sum1i $ maximum1i #x ,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ - fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ + fromNamed $ lambda @(TArr N2 R) #x $ body $ idx0 $ sum1i $ minimum1i #x - ,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ + ,adTest "unused" $ fromNamed $ lambda @(TArr N1 R) #x $ body $ let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $ 42 |