From a46f53695d1dfab8834c7cc52707c0c0bb9b8ba0 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 10 Nov 2024 22:40:54 +0100 Subject: Test gmm --- src/Example/GMM.hs | 15 ++++++++++++--- src/ForwardAD.hs | 9 ++++++--- src/Interpreter/Rep.hs | 10 +++++++++- 3 files changed, 27 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs index ff37f9a..1db88bd 100644 --- a/src/Example/GMM.hs +++ b/src/Example/GMM.hs @@ -32,8 +32,16 @@ type TMat = TArr (S (S Z)) -- Master thesis at Utrecht University. (Appendix B.1) -- -- -gmmObjective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R -gmmObjective = fromNamed $ +-- +-- The 'wrong' argument, when set to True, changes the objective function to +-- one with a bug that makes a certain `build` result unused. This triggers +-- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's +-- dense, even though it may be a zero (i.e. empty). The "unused" test in +-- test/Main.hs tries to isolate this test, but the wrong version of +-- gmmObjective is here to check (after that bug is fixed) whether it really +-- fixes the original bug. +gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R +gmmObjective wrong = fromNamed $ lambda #N $ lambda #D $ lambda #K $ lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $ lambda #X $ lambda #m $ @@ -100,7 +108,8 @@ gmmObjective = fromNamed $ if_ (#i .== #j) (exp (#q ! pair nil #i)) (if_ (#i .> #j) - (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j) + (if wrong then toFloat_ (#i * (#i - 1) `idiv` 2 + #j) + else #l ! pair nil (#i * (#i - 1) `idiv` 2 + #j)) 0.0) qmat q l = inline qmat' (SNil .$ q .$ l) in let_ #k2arr (unit #k2) $ diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 67d22dd..b95385c 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -7,11 +7,12 @@ module ForwardAD where import Data.Bifunctor (bimap) --- import Data.Foldable (toList) + +-- import Debug.Trace +-- import AST.Pretty import Array import AST --- import AST.Bindings import Data import ForwardAD.DualNumbers import Interpreter @@ -214,6 +215,8 @@ dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) drevByFwd env expr input dres = let outty = typeOf expr - in dnOnehotEnvs env input $ \dnInput -> + in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $ + dnOnehotEnvs env input $ \dnInput -> + -- trace (showEnv (dne env) dnInput) $ let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr)) in dotprodTan outty outtan dres diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 7ef9088..0007991 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -3,7 +3,7 @@ {-# LANGUAGE UndecidableInstances #-} module Interpreter.Rep where -import Data.List (intersperse) +import Data.List (intersperse, intercalate) import Data.Foldable (toList) import Data.IORef import GHC.TypeError @@ -11,6 +11,7 @@ import GHC.TypeError import Array import AST import AST.Pretty +import Data type family Rep t where @@ -76,3 +77,10 @@ showValue _ (STScal sty) x = case sty of STI64 -> shows x STBool -> shows x showValue _ (STAccum t) _ = showString $ "" + +showEnv :: SList STy env -> SList Value env -> String +showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" + where + showEntries :: SList STy env -> SList Value env -> [String] + showEntries SNil SNil = [] + showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs -- cgit v1.2.3-70-g09d2