summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Example/GMM.hs15
-rw-r--r--src/ForwardAD.hs9
-rw-r--r--src/Interpreter/Rep.hs10
3 files changed, 27 insertions, 7 deletions
diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs
index ff37f9a..1db88bd 100644
--- a/src/Example/GMM.hs
+++ b/src/Example/GMM.hs
@@ -32,8 +32,16 @@ type TMat = TArr (S (S Z))
-- Master thesis at Utrecht University. (Appendix B.1)
-- <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y>
-- <https://tomsmeding.com/f/master.pdf>
-gmmObjective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
-gmmObjective = fromNamed $
+--
+-- The 'wrong' argument, when set to True, changes the objective function to
+-- one with a bug that makes a certain `build` result unused. This triggers
+-- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's
+-- dense, even though it may be a zero (i.e. empty). The "unused" test in
+-- test/Main.hs tries to isolate this test, but the wrong version of
+-- gmmObjective is here to check (after that bug is fixed) whether it really
+-- fixes the original bug.
+gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
+gmmObjective wrong = fromNamed $
lambda #N $ lambda #D $ lambda #K $
lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $
lambda #X $ lambda #m $
@@ -100,7 +108,8 @@ gmmObjective = fromNamed $
if_ (#i .== #j)
(exp (#q ! pair nil #i))
(if_ (#i .> #j)
- (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j)
+ (if wrong then toFloat_ (#i * (#i - 1) `idiv` 2 + #j)
+ else #l ! pair nil (#i * (#i - 1) `idiv` 2 + #j))
0.0)
qmat q l = inline qmat' (SNil .$ q .$ l)
in let_ #k2arr (unit #k2) $
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index 67d22dd..b95385c 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -7,11 +7,12 @@
module ForwardAD where
import Data.Bifunctor (bimap)
--- import Data.Foldable (toList)
+
+-- import Debug.Trace
+-- import AST.Pretty
import Array
import AST
--- import AST.Bindings
import Data
import ForwardAD.DualNumbers
import Interpreter
@@ -214,6 +215,8 @@ dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =
drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
drevByFwd env expr input dres =
let outty = typeOf expr
- in dnOnehotEnvs env input $ \dnInput ->
+ in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $
+ dnOnehotEnvs env input $ \dnInput ->
+ -- trace (showEnv (dne env) dnInput) $
let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr))
in dotprodTan outty outtan dres
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 7ef9088..0007991 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -3,7 +3,7 @@
{-# LANGUAGE UndecidableInstances #-}
module Interpreter.Rep where
-import Data.List (intersperse)
+import Data.List (intersperse, intercalate)
import Data.Foldable (toList)
import Data.IORef
import GHC.TypeError
@@ -11,6 +11,7 @@ import GHC.TypeError
import Array
import AST
import AST.Pretty
+import Data
type family Rep t where
@@ -76,3 +77,10 @@ showValue _ (STScal sty) x = case sty of
STI64 -> shows x
STBool -> shows x
showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppTy 0 t ++ ">"
+
+showEnv :: SList STy env -> SList Value env -> String
+showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
+ where
+ showEntries :: SList STy env -> SList Value env -> [String]
+ showEntries SNil SNil = []
+ showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs