diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 22:40:54 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 22:40:54 +0100 |
commit | a46f53695d1dfab8834c7cc52707c0c0bb9b8ba0 (patch) | |
tree | 1f00fa82540f4a54ddbf45fc6e5717b6dd8d5f94 /src | |
parent | 4d573fa32997a8e4824bf8326fb675d0c195b1ac (diff) |
Test gmm
Diffstat (limited to 'src')
-rw-r--r-- | src/Example/GMM.hs | 15 | ||||
-rw-r--r-- | src/ForwardAD.hs | 9 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 10 |
3 files changed, 27 insertions, 7 deletions
diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs index ff37f9a..1db88bd 100644 --- a/src/Example/GMM.hs +++ b/src/Example/GMM.hs @@ -32,8 +32,16 @@ type TMat = TArr (S (S Z)) -- Master thesis at Utrecht University. (Appendix B.1) -- <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y> -- <https://tomsmeding.com/f/master.pdf> -gmmObjective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R -gmmObjective = fromNamed $ +-- +-- The 'wrong' argument, when set to True, changes the objective function to +-- one with a bug that makes a certain `build` result unused. This triggers +-- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's +-- dense, even though it may be a zero (i.e. empty). The "unused" test in +-- test/Main.hs tries to isolate this test, but the wrong version of +-- gmmObjective is here to check (after that bug is fixed) whether it really +-- fixes the original bug. +gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R +gmmObjective wrong = fromNamed $ lambda #N $ lambda #D $ lambda #K $ lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $ lambda #X $ lambda #m $ @@ -100,7 +108,8 @@ gmmObjective = fromNamed $ if_ (#i .== #j) (exp (#q ! pair nil #i)) (if_ (#i .> #j) - (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j) + (if wrong then toFloat_ (#i * (#i - 1) `idiv` 2 + #j) + else #l ! pair nil (#i * (#i - 1) `idiv` 2 + #j)) 0.0) qmat q l = inline qmat' (SNil .$ q .$ l) in let_ #k2arr (unit #k2) $ diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 67d22dd..b95385c 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -7,11 +7,12 @@ module ForwardAD where import Data.Bifunctor (bimap) --- import Data.Foldable (toList) + +-- import Debug.Trace +-- import AST.Pretty import Array import AST --- import AST.Bindings import Data import ForwardAD.DualNumbers import Interpreter @@ -214,6 +215,8 @@ dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) drevByFwd env expr input dres = let outty = typeOf expr - in dnOnehotEnvs env input $ \dnInput -> + in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $ + dnOnehotEnvs env input $ \dnInput -> + -- trace (showEnv (dne env) dnInput) $ let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr)) in dotprodTan outty outtan dres diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 7ef9088..0007991 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -3,7 +3,7 @@ {-# LANGUAGE UndecidableInstances #-} module Interpreter.Rep where -import Data.List (intersperse) +import Data.List (intersperse, intercalate) import Data.Foldable (toList) import Data.IORef import GHC.TypeError @@ -11,6 +11,7 @@ import GHC.TypeError import Array import AST import AST.Pretty +import Data type family Rep t where @@ -76,3 +77,10 @@ showValue _ (STScal sty) x = case sty of STI64 -> shows x STBool -> shows x showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppTy 0 t ++ ">" + +showEnv :: SList STy env -> SList Value env -> String +showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" + where + showEntries :: SList STy env -> SList Value env -> [String] + showEntries SNil SNil = [] + showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs |