diff options
| -rw-r--r-- | test/Main.hs | 41 | 
1 files changed, 22 insertions, 19 deletions
| diff --git a/test/Main.hs b/test/Main.hs index 014ad43..933629c 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -41,6 +41,9 @@ import Language  import Simplify +type R = TScal TF64 + +  data SimplIters = SimplIters Int | SimplFix    deriving (Show) @@ -51,19 +54,19 @@ simplifyIters iters env | Dict <- envKnown env =      SimplFix -> simplifyFix  -- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) +gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))  gradientByCHAD simplIters env term input =    let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term        (out, grad) = interpretOpen False input dterm    in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))  -- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env))) +gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))  gradientByCHAD' simplIters env term input =    second (second (toTanE env input)) $      gradientByCHAD simplIters env term input -gradientByForward :: FwdADArtifact env (TScal TF64) -> SList Value env -> SList Value (TanE env) +gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)  gradientByForward art input = drevByFwd art input 1.0  extendDN :: STy t -> Rep t -> Gen (Rep (DN t)) @@ -227,20 +230,20 @@ compileTest name expr =         let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y         diff (TypedValue t resI) cmp (TypedValue t resC) -adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree +adTest :: forall env. KnownEnv env => TestName -> Ex env R -> TestTree  adTest name = adTestCon name (const True) -adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env (TScal TF64) -> TestTree +adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env R -> TestTree  adTestCon name constr term =    let env = knownEnv    in adTestGen name term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty))  adTestTp :: forall env. KnownEnv env -         => TestName -> TemplateE env -> Ex env (TScal TF64) -> TestTree +         => TestName -> TemplateE env -> Ex env R -> TestTree  adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) mempty)  adTestGen :: forall env. KnownEnv env -          => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree +          => TestName -> Ex env R -> Gen (SList Value env) -> TestTree  adTestGen name expr envGenerator =    let env = knownEnv @env        exprS = simplifyFix expr @@ -252,7 +255,7 @@ adTestGen name expr envGenerator =         ,adTestGenChad env envGenerator expr exprS primalSfun]  adTestGenPrimal :: SList STy env -> Gen (SList Value env) -                -> Ex env (TScal TF64) -> Ex env (TScal TF64) +                -> Ex env R -> Ex env R                  -> (SList Value env -> IO Double) -> (SList Value env -> IO Double)                  -> TestTree  adTestGenPrimal env envGenerator expr exprS primalfun primalSfun = @@ -268,7 +271,7 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =      diff outPrimalSI (closeIsh' 1e-8) outPrimalSC  adTestGenFwd :: SList STy env -> Gen (SList Value env) -             -> Ex env (TScal TF64) +             -> Ex env R               -> TestTree  adTestGenFwd env envGenerator exprS =    withCompiled (dne env) (dfwdDN exprS) $ \dnfun -> @@ -281,7 +284,7 @@ adTestGenFwd env envGenerator exprS =        diff outDNI2 (closeIsh' 1e-8) outDNC2  adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) -              -> Ex env (TScal TF64) -> Ex env (TScal TF64) +              -> Ex env R -> Ex env R                -> (SList Value env -> IO Double)                -> TestTree  adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = @@ -341,18 +344,18 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =  withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree  withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) -term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) +term_build1_sum :: Ex '[TArr N1 R] R  term_build1_sum = fromNamed $ lambda #x $ body $    idx0 $ sum1i $      build (SS SZ) (shape #x) $ #idx :-> #x ! #idx -term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64) +term_pairs :: Ex [R, R] R  term_pairs = fromNamed $ lambda #x $ lambda #y $ body $    let_ #p (pair #x #y) $    let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $      fst_ #q * #x + snd_ #q * fst_ #p -term_sparse :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) +term_sparse :: Ex '[TArr N1 R] R  term_sparse = fromNamed $ lambda #inp $ body $    let_ #n (snd_ (shape #inp)) $    let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $ @@ -361,7 +364,7 @@ term_sparse = fromNamed $ lambda #inp $ body $    let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $      idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c) -term_regression_simpl1 :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) +term_regression_simpl1 :: Ex '[TArr N1 R] R  term_regression_simpl1 = fromNamed $ lambda #q $ body $    idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :->      let_ #j (snd_ #idx) $ @@ -369,7 +372,7 @@ term_regression_simpl1 = fromNamed $ lambda #q $ body $          (#q ! pair nil 0)          (if_ (#j .== #j) 1.0 2.0) -term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64) +term_mulmatvec :: Ex [TArr N1 R, TArr N2 R] R  term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $    idx0 $ sum1i $      let_ #hei (snd_ (fst_ (shape #mat))) $ @@ -414,7 +417,7 @@ tests_AD = testGroup "AD"    ,adTest "pairs" term_pairs -  ,adTest "build0 const" $ fromNamed $ lambda @(TScal TF64) #x $ body $ +  ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $        idx0 $ build SZ nil $ #idx :-> const_ 0.0    ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $ @@ -428,14 +431,14 @@ tests_AD = testGroup "AD"          build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx    ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ -      fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ +      fromNamed $ lambda @(TArr N2 R) #x $ body $          idx0 $ sum1i $ maximum1i #x    ,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ -      fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ +      fromNamed $ lambda @(TArr N2 R) #x $ body $          idx0 $ sum1i $ minimum1i #x -  ,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ +  ,adTest "unused" $ fromNamed $ lambda @(TArr N1 R) #x $ body $        let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $          42 | 
