diff options
| -rw-r--r-- | chad-fast.cabal | 2 | ||||
| -rw-r--r-- | src/Example/GMM.hs | 15 | ||||
| -rw-r--r-- | src/ForwardAD.hs | 9 | ||||
| -rw-r--r-- | src/Interpreter/Rep.hs | 10 | ||||
| -rw-r--r-- | test/Main.hs | 140 | 
5 files changed, 140 insertions, 36 deletions
| diff --git a/chad-fast.cabal b/chad-fast.cabal index 274e497..8817718 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -65,8 +65,10 @@ test-suite test    build-depends:      chad-fast,      base, +    containers,      dependent-map,      hedgehog, +    transformers,    hs-source-dirs: test    default-language: Haskell2010    ghc-options: -Wall -threaded diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs index ff37f9a..1db88bd 100644 --- a/src/Example/GMM.hs +++ b/src/Example/GMM.hs @@ -32,8 +32,16 @@ type TMat = TArr (S (S Z))  --   Master thesis at Utrecht University. (Appendix B.1)  --   <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y>  --   <https://tomsmeding.com/f/master.pdf> -gmmObjective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R -gmmObjective = fromNamed $ +-- +-- The 'wrong' argument, when set to True, changes the objective function to +-- one with a bug that makes a certain `build` result unused. This triggers +-- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's +-- dense, even though it may be a zero (i.e. empty). The "unused" test in +-- test/Main.hs tries to isolate this test, but the wrong version of +-- gmmObjective is here to check (after that bug is fixed) whether it really +-- fixes the original bug. +gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R +gmmObjective wrong = fromNamed $    lambda #N $ lambda #D $ lambda #K $    lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $    lambda #X $ lambda #m $ @@ -100,7 +108,8 @@ gmmObjective = fromNamed $                          if_ (#i .== #j)                            (exp (#q ! pair nil #i))                            (if_ (#i .> #j) -                            (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j) +                            (if wrong then toFloat_ (#i * (#i - 1) `idiv` 2 + #j) +                                      else #l ! pair nil (#i * (#i - 1) `idiv` 2 + #j))                              0.0)          qmat q l = inline qmat' (SNil .$ q .$ l)      in let_ #k2arr (unit #k2) $ diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 67d22dd..b95385c 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -7,11 +7,12 @@  module ForwardAD where  import Data.Bifunctor (bimap) --- import Data.Foldable (toList) + +-- import Debug.Trace +-- import AST.Pretty  import Array  import AST --- import AST.Bindings  import Data  import ForwardAD.DualNumbers  import Interpreter @@ -214,6 +215,8 @@ dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =  drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)  drevByFwd env expr input dres =    let outty = typeOf expr -  in dnOnehotEnvs env input $ \dnInput -> +  in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $ +     dnOnehotEnvs env input $ \dnInput -> +       -- trace (showEnv (dne env) dnInput) $         let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr))         in dotprodTan outty outtan dres diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 7ef9088..0007991 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -3,7 +3,7 @@  {-# LANGUAGE UndecidableInstances #-}  module Interpreter.Rep where -import Data.List (intersperse) +import Data.List (intersperse, intercalate)  import Data.Foldable (toList)  import Data.IORef  import GHC.TypeError @@ -11,6 +11,7 @@ import GHC.TypeError  import Array  import AST  import AST.Pretty +import Data  type family Rep t where @@ -76,3 +77,10 @@ showValue _ (STScal sty) x = case sty of    STI64 -> shows x    STBool -> shows x  showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppTy 0 t ++ ">" + +showEnv :: SList STy env -> SList Value env -> String +showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" +  where +    showEntries :: SList STy env -> SList Value env -> [String] +    showEntries SNil SNil = [] +    showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs diff --git a/test/Main.hs b/test/Main.hs index 75ab11a..72b7809 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -7,11 +7,15 @@  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-}  module Main where +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State  import Data.Bifunctor  import Data.Int (Int64) -import Data.List (intercalate) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map  import Hedgehog  import qualified Hedgehog.Gen as Gen  import qualified Hedgehog.Range as Range @@ -86,33 +90,89 @@ closeIsh :: Double -> Double -> Bool  closeIsh a b =    abs (a - b) < 1e-5 || (let scale = min (abs a) (abs b) in scale > 1e-4 && abs (a - b) / scale < 1e-5) -genShape :: SNat n -> Gen (Shape n) -genShape = \n -> do -  sh <- genShapeNaive n +data a :$ b = a :$ b deriving (Show) ; infixl :$ + +-- An empty name means "no restrictions". +data TplConstr = C String  -- ^ name; @""@ means anonymous +                   Int     -- ^ minimum value to generate + +type family DimNames n where +  DimNames Z = () +  DimNames (S Z) = TplConstr +  DimNames (S n) = DimNames n :$ TplConstr + +type family Tpl t where +  Tpl (TArr n t) = DimNames n +  Tpl (TPair a b) = (Tpl a, Tpl b) +  -- If you add equations here, don't forget to update genValue! It currently +  -- just emptyTpl's things out. +  Tpl _ = () + +data a :& b = a :& b deriving (Show) ; infixl :& + +type family TemplateE env where +  TemplateE '[] = () +  TemplateE '[t] = Tpl t +  TemplateE (t : ts) = TemplateE ts :& Tpl t + +emptyDimNames :: SNat n -> DimNames n +emptyDimNames SZ = () +emptyDimNames (SS SZ) = C "" 0 +emptyDimNames (SS n@SS{}) = emptyDimNames n :$ C "" 0 + +emptyTpl :: STy t -> Tpl t +emptyTpl (STArr n _) = emptyDimNames n +emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b) +emptyTpl (STScal _) = () +emptyTpl _ = error "too lazy" + +emptyTemplateE :: SList STy env -> TemplateE env +emptyTemplateE SNil = () +emptyTemplateE (t `SCons` SNil) = emptyTpl t +emptyTemplateE (t `SCons` ts@SCons{}) = emptyTemplateE ts :& emptyTpl t + +genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n) +genShape = \n tpl -> do +  sh <- genShapeNaive n tpl    let sz = shapeSize sh        factor = sz `div` 100 + 1    return (shapeDiv sh factor)    where -    genShapeNaive :: SNat n -> Gen (Shape n) -    genShapeNaive SZ = return ShNil -    genShapeNaive (SS n) = ShCons <$> genShapeNaive n <*> Gen.integral (Range.linear 0 10) +    genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n) +    genShapeNaive SZ () = return ShNil +    genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name +    genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name + +    genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int +    genNamedDim (C "" lo) = genDim lo +    genNamedDim (C name lo) = gets (Map.lookup name) >>= \case +      Nothing -> do +        dim <- genDim lo +        modify (Map.insert name dim) +        return dim +      Just dim -> return dim + +    genDim :: Int -> StateT (Map String Int) Gen Int +    genDim lo = Gen.integral (Range.linear lo 10)      shapeDiv :: Shape n -> Int -> Shape n      shapeDiv ShNil _ = ShNil      shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f)  genArray :: STy a -> Shape n -> Gen (Value (TArr n a)) -genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> genValue t) +genArray t sh = +  Value <$> arrayGenerateLinM sh (\_ -> +              unValue <$> evalStateT (genValue t (emptyTpl t)) mempty) -genValue :: STy a -> Gen (Value a) -genValue = \case +genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t) +genValue topty tpl = case topty of    STNil -> return (Value ()) -  STPair a b -> liftV2 (,) <$> genValue a <*> genValue b -  STEither a b -> Gen.choice [liftV Left <$> genValue a -                             ,liftV Right <$> genValue b] +  STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl) +  STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a) +                             ,liftV Right <$> genValue b (emptyTpl b)]    STMaybe t -> Gen.choice [return (Value Nothing) -                          ,liftV Just <$> genValue t] -  STArr n t -> genShape n >>= genArray t +                          ,liftV Just <$> genValue t (emptyTpl t)] +  STArr n t -> genShape n tpl >>= lift . genArray t    STScal sty -> case sty of      STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)      STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) @@ -121,22 +181,22 @@ genValue = \case      STBool -> Gen.choice [return (Value False), return (Value True)]    STAccum{} -> error "Cannot generate inputs for accumulators" -genEnv :: SList STy env -> Gen (SList Value env) -genEnv SNil = return SNil -genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env - -showEnv :: SList STy env -> SList Value env -> String -showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" -  where -    showEntries :: SList STy env -> SList Value env -> [String] -    showEntries SNil SNil = [] -    showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs +genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env) +genEnv SNil () = return SNil +genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil +genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl  adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property  adTest = adTestCon (const True)  adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property -adTestCon constr term = adTestGen term (Gen.filter constr (genEnv (knownEnv @env))) +adTestCon constr term = +  let env = knownEnv +  in adTestGen term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty)) + +adTestTp :: forall env. KnownEnv env +         => TemplateE env -> Ex env (TScal TF64) -> Property +adTestTp tmpl term = adTestGen term (evalStateT (genEnv knownEnv tmpl) mempty)  adTestGen :: forall env. KnownEnv env            => Ex env (TScal TF64) -> Gen (SList Value env) -> Property @@ -210,6 +270,10 @@ tests = checkSequential $ Group "AD"        fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $          idx0 $ sum1i $ minimum1i #x) +  ,("unused", adTest $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ +    let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $ +      42) +    ,("neural", adTestGen Example.neural $ do        let tR = STScal STF64        let genLayer nin nout = @@ -224,7 +288,26 @@ tests = checkSequential $ Group "AD"        lay3 <- genArray tR (ShNil `ShCons` n2)        return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) -  ,("gmm", withShrinks 0 $ adTestGen Example.gmmObjective $ do +  ,("logsumexp", adTestTp (C "" 1) $ +      fromNamed $ lambda @(TArr N1 _) #vec $ body $ +      let_ #m (maximum1i #vec) $ +        log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m) + +  ,("mulmatvec", adTestTp ((C "" 0 :$ C "n" 0) :& C "n" 0) $ +      fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $ +      idx0 $ sum1i $ +        let_ #hei (snd_ (fst_ (shape #mat))) $ +        let_ #wid (snd_ (shape #mat)) $ +          build1 #hei $ #i :-> +            idx0 (sum1i (build1 #wid $ #j :-> +                           #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))) + +  ,("gmm-wrong", withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM) + +  ,("gmm", withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM) +  ] +  where +    genGMM = do        -- The input ranges here are completely arbitrary.        let tR = STScal STF64        kN <- Gen.integral (Range.linear 1 10) @@ -245,8 +328,7 @@ tests = checkSequential $ Group "AD"                Value vm `SCons` vX `SCons`                vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`                Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` -              SNil)) -  ] +              SNil)  main :: IO ()  main = defaultMain [tests] | 
