{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Main where

import Control.Monad.Trans.Class (lift)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Text as T
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Test.Tasty
import Test.Tasty.Hedgehog

import Array
import AST
import AST.Pretty
import AST.UnMonoid
import CHAD.Top
import CHAD.Types
import CHAD.Types.ToTan
import Compile
import qualified Example
import qualified Example.GMM as Example
import ForwardAD
import Interpreter
import Interpreter.Rep
import Language
import Simplify


-- | How the simplifier should be run on a term.
data SimplIters
  = SimplIters Int  -- ^ run 'simplifyN' with this iteration count
  | SimplFix        -- ^ run 'simplifyFix' (presumably iterates to a fixpoint)
  deriving (Show)

-- | Simplify a term according to the given 'SimplIters' setting.
--
-- Matching on @'envKnown' env@ brings the constraint needed by the
-- simplifier into scope (presumably 'KnownEnv env' -- TODO confirm).
simplifyIters :: SimplIters -> SList STy env -> Ex env t -> Ex env t
simplifyIters iters env = case (envKnown env, iters) of
  (Dict, SimplIters n) -> simplifyN n
  (Dict, SimplFix) -> simplifyFix

-- | Differentiate @term@ with CHAD, simplify the result, and interpret it on
-- @input@.
--
-- The CHAD-transformed term is wrapped in a let binding the constant 1.0
-- (presumably the output cotangent seed). Returns the pretty-printed
-- differentiated term, together with the primal output and the gradient
-- (one 'Value' per environment entry, typed by 'D2E').
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
  (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
  where
    dterm = simplifyIters simplIters env
              (ELet ext (EConst ext STF64 1.0) (chad' defaultConfig env term))
    (out, grad) = interpretOpen False input dterm

-- In addition to the gradient, also returns the pretty-printed differentiated term.
-- The gradient is converted to tangent representation with 'toTanE', so it
-- can be compared directly against 'gradientByForward'.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' simplIters env term input = second (second (toTanE env input)) $ gradientByCHAD simplIters env term input

-- | Gradient of @term@ at @input@ computed via 'drevByFwd' (passing 1.0 as
-- the final argument, presumably the output seed).
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0

-- | @closeIsh' h a b@: do @a@ and @b@ agree up to tolerance @h@?
-- Accepts when the absolute difference is below @h@, or when both
-- magnitudes exceed @10*h@ and the difference relative to the smaller
-- magnitude is below @h@.
closeIsh' :: Double -> Double -> Double -> Bool
closeIsh' h a b =
  abs (a - b) < h
    || (let scale = min (abs a) (abs b)
        in scale > 10*h && abs (a - b) / scale < h)

-- | 'closeIsh'' with tolerance @1e-5@.
closeIsh :: Double -> Double -> Bool
closeIsh = closeIsh' 1e-5

data a :$ b = a :$ b
  deriving (Show)
infixl :$

-- | A constraint on one array dimension in an input template.
-- An empty name means "no restrictions"; dimensions that share a non-empty
-- name get the same size across the whole generated environment.
data TplConstr =
  C String  -- ^ name; @""@ means anonymous
    Int     -- ^ minimum value to generate

-- | One 'TplConstr' per array dimension, in the same left-to-right order as
-- the dimensions of the corresponding 'Shape'.
type family DimNames n where
  DimNames Z = ()
  DimNames (S Z) = TplConstr
  DimNames (S n) = DimNames n :$ TplConstr

-- | Input template for a value of type @t@.
type family Tpl t where
  Tpl (TArr n t) = DimNames n
  Tpl (TPair a b) = (Tpl a, Tpl b)
  -- If you add equations here, don't forget to update genValue! It currently
  -- just emptyTpl's things out.
  Tpl _ = ()

data a :& b = a :& b
  deriving (Show)
infixl :&

-- | Input template for a whole environment; the head of the list is the
-- rightmost component.
type family TemplateE env where
  TemplateE '[] = ()
  TemplateE '[t] = Tpl t
  TemplateE (t : ts) = TemplateE ts :& Tpl t

-- | Dimension template with every dimension anonymous and minimum 0.
emptyDimNames :: SNat n -> DimNames n
emptyDimNames SZ = ()
emptyDimNames (SS SZ) = C "" 0
emptyDimNames (SS n@SS{}) = emptyDimNames n :$ C "" 0

-- | A template imposing no constraints on a value of type @t@.
--
-- Every type for which 'Tpl' reduces to @()@ is handled explicitly, since
-- 'genValue' calls this on the components of 'STEither' / 'STMaybe' and on
-- array element types. Only accumulators remain unsupported, and 'genValue'
-- rejects those anyway.
emptyTpl :: STy t -> Tpl t
emptyTpl (STArr n _) = emptyDimNames n
emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b)
emptyTpl (STScal _) = ()
emptyTpl STNil = ()
emptyTpl STEither{} = ()
emptyTpl STMaybe{} = ()
emptyTpl _ = error "too lazy"

-- | Unconstrained template for an entire environment.
emptyTemplateE :: SList STy env -> TemplateE env
emptyTemplateE SNil = ()
emptyTemplateE (t `SCons` SNil) = emptyTpl t
emptyTemplateE (t `SCons` ts@SCons{}) = emptyTemplateE ts :& emptyTpl t

-- | Generate a shape satisfying a dimension template.
--
-- Named dimensions are looked up in (or recorded into) the state map, so
-- equal names yield equal sizes. If the naive shape has 100 or more
-- elements, the anonymous dimensions are divided by @size `div` 100 + 1@ to
-- keep arrays small, clamped at their minimum. Named dimensions are never
-- shrunk: shrinking them here would desynchronise them from other arrays
-- sharing the same name.
genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
genShape = \n tpl -> do
  sh <- genShapeNaive n tpl
  let sz = shapeSize sh
      factor = sz `div` 100 + 1
  return (shrinkShape n tpl factor sh)
  where
    genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
    genShapeNaive SZ () = return ShNil
    genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name
    genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name

    genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int
    genNamedDim (C "" lo) = genDim lo
    genNamedDim (C name lo) = gets (Map.lookup name) >>= \case
      Nothing -> do
        dim <- genDim lo
        modify (Map.insert name dim)
        return dim
      Just dim -> return dim

    genDim :: Int -> StateT (Map String Int) Gen Int
    genDim lo = Gen.integral (Range.linear lo 10)

    -- Divide each anonymous dimension by the factor (never below its
    -- minimum); leave named dimensions untouched.
    shrinkShape :: SNat n -> DimNames n -> Int -> Shape n -> Shape n
    shrinkShape SZ () _ sh = sh
    shrinkShape (SS SZ) c f (sh `ShCons` d) = sh `ShCons` shrinkDim c f d
    shrinkShape (SS n@SS{}) (tpl :$ c) f (sh `ShCons` d) = shrinkShape n tpl f sh `ShCons` shrinkDim c f d

    shrinkDim :: TplConstr -> Int -> Int -> Int
    shrinkDim (C "" lo) f d = max lo (d `div` f)
    shrinkDim C{} _ d = d

-- | Generate an array of the given shape. Each element is generated
-- independently with an unconstrained template and a fresh name map.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> evalStateT (genValue t (emptyTpl t)) mempty)

-- | Generate a random value of type @t@ that respects the template.
genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
genValue topty tpl = case topty of
  STNil -> return (Value ())
  STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl)
  STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a)
                             ,liftV Right <$> genValue b (emptyTpl b)]
  STMaybe t -> Gen.choice [return (Value Nothing)
                          ,liftV Just <$> genValue t (emptyTpl t)]
  STArr n t -> genShape n tpl >>= lift . genArray t
  STScal sty -> case sty of
    STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STBool -> Gen.choice [return (Value False), return (Value True)]
  STAccum{} -> error "Cannot generate inputs for accumulators"

-- | Generate a whole environment, threading the dimension-name map through
-- all entries so that named dimensions agree across inputs.
genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env)
genEnv SNil () = return SNil
genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl

-- | AD test with unconstrained inputs.
adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree
adTest name = adTestCon name (const True)

-- | AD test with unconstrained inputs, filtered by a predicate.
adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env (TScal TF64) -> TestTree
adTestCon name constr term =
  let env = knownEnv
  in adTestGen name term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty))

-- | AD test with inputs generated according to an environment template.
adTestTp :: forall env. KnownEnv env => TestName -> TemplateE env -> Ex env (TScal TF64) -> TestTree
adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) mempty)

-- | AD property test with a custom input generator. Checks that:
--
-- * 'simplifyFix' and @'simplifyN' 20@ give the same CHAD term;
-- * the compiled primal agrees with the interpreted primal (tolerance 1e-8);
-- * unsimplified and simplified CHAD agree on the output and the gradient;
-- * the simplified CHAD output matches the compiled primal output, and its
--   gradient matches the forward-mode gradient.
adTestGen :: forall env. KnownEnv env
          => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
adTestGen name expr envGenerator =
  withCompiled expr $ \getprimalfun ->
  testProperty name $ property $ do
    let env = knownEnv @env
    annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))

    let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
        dtermChadS = simplifyFix dtermChad0
        dtermChadS20 = simplifyN 20 dtermChad0
    -- pack Text for less GC pressure (these values are retained for some reason)
    diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChadS20))

    input <- forAllWith (showEnv env) envGenerator

    let convGrad :: Rep (Tup (D2E env)) -> SList Value (TanE env)
        convGrad = toTanE env input . unTup vUnpair (d2e env) . Value

    let outPrimalI = interpretOpen False input expr
    outPrimal <- liftIO $ getprimalfun >>= ($ input)
    diff outPrimal (closeIsh' 1e-8) outPrimalI

    let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
        (outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS
        scChad = envScalars env gradChad0
        scChadS = envScalars env gradChadS
        gradFwd = gradientByForward knownEnv expr input
        scFwd = envScalars env gradFwd

    -- annotate (ppExpr knownEnv expr)
    -- annotate (ppExpr env dtermChad0)
    -- annotate (ppExpr env dtermChadS)
    diff outChadS closeIsh outChad0
    diff outChadS closeIsh outPrimal
    diff scChadS closeIshList scChad
    diff scChadS closeIshList scFwd
  where
    -- Flatten a tangent environment into its scalar components.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs

    -- Pointwise 'closeIsh' that also requires equal lengths; a bare
    -- 'zipWith' would silently ignore trailing elements.
    closeIshList :: [Double] -> [Double] -> Bool
    closeIshList xs ys = length xs == length ys && and (zipWith closeIsh xs ys)

-- | Compile @expr@ once for the whole test tree (via tasty's 'withResource')
-- and hand the tree an action that yields the compiled primal function.
withCompiled :: KnownEnv env => Ex env t -> (IO (SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled expr = withResource (compile knownEnv expr) (\_ -> pure ())

term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_build1_sum = fromNamed $ lambda #x $ body $
  idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx

term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
  let_ #p (pair #x #y) $
  let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
    fst_ #q * #x + snd_ #q * fst_ #p

-- | Reads elements 2, 3 and 4 of the input, so it needs length >= 5.
term_sparse :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_sparse = fromNamed $ lambda #inp $ body $
  let_ #n (snd_ (shape #inp)) $
  let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $
  let_ #a (build1 #n (#i :-> #arr ! pair nil 2)) $
  let_ #b (build1 #n (#i :-> #arr ! pair nil 3)) $
  let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $
    idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)

term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64)
term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
  idx0 $ sum1i $
    let_ #hei (snd_ (fst_ (shape #mat))) $
    let_ #wid (snd_ (shape #mat)) $
    build1 #hei $ #i :->
      idx0 (sum1i
        (build1 #wid $ #j :->
          #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))

tests :: TestTree
tests = testGroup "AD"
  [adTest "id" $ fromNamed $ lambda #x $ body $ #x

  ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x

  ,adTest "operators" $ fromNamed $ lambda #x $ lambda #y $ body $
     let_ #i (round_ #x) $
     let_ #j (round_ #y) $
     let_ #a1 (#x + #y) $
     let_ #a2 (#x - #y) $
     let_ #a3 (#x * #y) $
     let_ #a4 (#x / (#y * #y + 1)) $
     let_ #b1 (#i + #j) $
     let_ #b2 (#i - #j) $
     let_ #b3 (#i * #j) $
     let_ #b4 (#i `idiv` (#j * #j + 1)) $
       #a1 + #a2 + #a3 + #a4 + toFloat_ (#b1 + #b2 + #b3 + #b4)

  ,adTest "order-of-operations" $ fromNamed $ body $
     toFloat_ (3 * (3 `idiv` 2))  -- Compile had a pretty-printing bug at some point

  ,adTest "sum-vec" $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)

  ,adTest "sum-replicate" $ fromNamed $ lambda #x $ body $
     idx0 $ sum1i $ replicate1i 10 #x

  ,adTest "pairs" term_pairs

  ,adTest "build0 const" $ fromNamed $ lambda @(TScal TF64) #x $ body $
     idx0 $ build SZ nil $ #idx :-> const_ 0.0

  ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
     idx0 $
     build SZ (shape #x) $ #idx :-> #x ! #idx

  ,adTest "build1-sum" term_build1_sum

  ,adTest "build2-sum" $ fromNamed $ lambda @(TArr N2 _) #x $ body $
     idx0 $ sum1i . sum1i $
       build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx

  ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
     fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
       idx0 $ sum1i $ maximum1i #x

  ,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
     fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
       idx0 $ sum1i $ minimum1i #x

  ,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
     let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
       42

  ,adTestTp "sparse" (C "" 5) term_sparse

  ,adTestGen "neural" Example.neural genNeural
  ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural

  ,adTestTp "logsumexp" (C "" 1) $
     fromNamed $ lambda @(TArr N1 _) #vec $ body $
       let_ #m (maximum1i #vec) $
       log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m

  ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) term_mulmatvec

  ,adTestGen "gmm-wrong" (Example.gmmObjective True) genGMM
  ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM
  ,adTestGen "gmm" (Example.gmmObjective False) genGMM
  ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM
  ]
  where
    genGMM = do
      -- The input ranges here are completely arbitrary.
      let tR = STScal STF64
      kN <- Gen.integral (Range.linear 1 10)
      kD <- Gen.integral (Range.linear 1 10)
      kK <- Gen.integral (Range.linear 1 10)
      let i2i64 = fromIntegral @Int @Int64
      valpha <- genArray tR (ShNil `ShCons` kK)
      vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
      vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
      vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
      vm <- Gen.integral (Range.linear 0 5)
      let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
          k2 = 0.5 * vgamma * vgamma
          k3 = 0.42  -- don't feel like multigammaing today
      return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
              Value vm `SCons` vX `SCons`
              vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
              Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
              SNil)

    genNeural = do
      let tR = STScal STF64
      let genLayer nin nout =
            liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                       <*> genArray tR (ShNil `ShCons` nout)
      nin <- Gen.integral (Range.linear 1 10)
      n1 <- Gen.integral (Range.linear 1 10)
      n2 <- Gen.integral (Range.linear 1 10)
      input <- genArray tR (ShNil `ShCons` nin)
      lay1 <- genLayer nin n1
      lay2 <- genLayer n1 n2
      lay3 <- genArray tR (ShNil `ShCons` n2)
      return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)

main :: IO ()
main = defaultMain tests