{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Main where

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main

import Array
import AST
import AST.Pretty
import CHAD.Top
import CHAD.Types
import qualified Example
import qualified Example.GMM as Example
import ForwardAD
import Interpreter
import Interpreter.Rep
import Language
import Simplify


-- How the differentiated term is simplified before it is interpreted:
-- @SimplIters n@ is passed on to 'simplifyN' as its count, @SimplFix@ selects
-- 'simplifyFix'.
data SimplIters
  = SimplIters Int
  | SimplFix
  deriving (Show)

-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
  (ppExpr env simplified, (primalOut, unTup vUnpair (d2e env) (Value cotangents)))
  where
    -- The output of chad' is placed under a let that binds the constant 1.0
    -- (presumably the incoming cotangent of the scalar result -- confirm
    -- against chad').
    unsimplified = ELet ext (EConst ext STF64 1.0) (chad' defaultConfig env term)

    -- Matching on the Dict brings the evidence from 'envKnown' into scope for
    -- the simplifier calls.
    simplified
      | Dict <- envKnown env =
          case simplIters of
            SimplIters n -> simplifyN n unsimplified
            SimplFix -> simplifyFix unsimplified

    (primalOut, cotangents) = interpretOpen False input simplified

-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' simplIters env term input =
  -- Only the innermost component (the gradient) is rewritten; the printed
  -- term and the primal output pass through unchanged.
  second (second (toTanE env input)) (gradientByCHAD simplIters env term input)
  where
    -- Convert a whole environment of cotangents, pairing every entry with its
    -- type and its primal value.
    toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (ty `SCons` tys) (Value p `SCons` ps) (Value d `SCons` ds) =
      Value (toTan ty p d) `SCons` toTanE tys ps ds

    -- Convert one cotangent to tangent form. The primal value is needed to
    -- expand the compact zero representations (@Left ()@, an empty array)
    -- into a zero of the right structure.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan STNil _ d = d
    toTan (STPair ta tb) p d =
      case d of
        Left () -> bimap (zeroTan ta) (zeroTan tb) p
        Right (da, db) -> bimap (\x -> toTan ta x da) (\y -> toTan tb y db) p
    toTan (STEither ta tb) p d =
      case d of
        Left () -> bimap (zeroTan ta) (zeroTan tb) p
        Right d' ->
          case (p, d') of
            (Left x, Left dx) -> Left (toTan ta x dx)
            (Right y, Right dy) -> Right (toTan tb y dy)
            _ -> error "Primal and cotangent disagree on Either alternative"
    toTan (STMaybe t) p d = toTan t <$> p <*> d
    toTan (STArr _ t) p d
      | shapeSize (arrayShape d) == 0 = arrayMap (zeroTan t) p
      | arrayShape p == arrayShape d =
          arrayGenerateLin (arrayShape p) $ \i ->
            toTan t (arrayIndexLinear p i) (arrayIndexLinear d i)
      | otherwise = error "Primal and cotangent disagree on array shape"
    toTan (STScal sty) _ d =
      -- Every scalar type passes the cotangent through; the match on the
      -- scalar singleton is still spelled out per constructor.
      case sty of
        STI32 -> d
        STI64 -> d
        STF32 -> d
        STF64 -> d
        STBool -> d
    toTan STAccum{} _ _ = error "Accumulators not allowed in input program"

-- Reference gradient obtained from 'drevByFwd', called with 1.0 as its last
-- argument.
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0

-- Approximate equality: either the absolute difference is below 1e-5, or the
-- smaller magnitude exceeds 1e-4 and the difference relative to it is below
-- 1e-5.
closeIsh :: Double -> Double -> Bool
closeIsh a b = absErr < 1e-5 || (scale > 1e-4 && absErr / scale < 1e-5)
  where
    absErr = abs (a - b)
    scale = min (abs a) (abs b)

data a :$ b = a :$ b
  deriving (Show)
infixl :$

-- An empty name means "no restrictions".
data TplConstr = C String  -- ^ name; @""@ means anonymous
                   Int     -- ^ minimum value to generate

-- Per-dimension constraints for an array of rank @n@: nothing for rank 0, a
-- single constraint for rank 1, and a left-nested ':$' chain beyond that.
type family DimNames n where
  DimNames Z = ()
  DimNames (S Z) = TplConstr
  DimNames (S n) = DimNames n :$ TplConstr

-- Generation template for a value of type @t@.
type family Tpl t where
  Tpl (TArr n t) = DimNames n
  Tpl (TPair a b) = (Tpl a, Tpl b)
  -- If you add equations here, don't forget to update genValue! It currently
  -- just emptyTpl's things out.
  Tpl _ = ()

data a :& b = a :& b
  deriving (Show)
infixl :&

-- Generation template for a whole environment. Note that the template of the
-- head of the environment is the rightmost component of the ':&' chain.
type family TemplateE env where
  TemplateE '[] = ()
  TemplateE '[t] = Tpl t
  TemplateE (t : ts) = TemplateE ts :& Tpl t

-- Dimension constraints that are all anonymous with minimum 0.
emptyDimNames :: SNat n -> DimNames n
emptyDimNames SZ = ()
emptyDimNames (SS SZ) = C "" 0
emptyDimNames (SS n@SS{}) = emptyDimNames n :$ C "" 0

-- The unconstrained template for a type. The types for which 'Tpl' is @()@
-- and that 'genValue' can generate are handled explicitly, so that forcing
-- the template of such a type cannot hit the catch-all error.
emptyTpl :: STy t -> Tpl t
emptyTpl (STArr n _) = emptyDimNames n
emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b)
emptyTpl (STScal _) = ()
emptyTpl STNil = ()
emptyTpl STEither{} = ()
emptyTpl STMaybe{} = ()
emptyTpl _ = error "too lazy"

-- The unconstrained template for an environment.
emptyTemplateE :: SList STy env -> TemplateE env
emptyTemplateE SNil = ()
emptyTemplateE (t `SCons` SNil) = emptyTpl t
emptyTemplateE (t `SCons` ts@SCons{}) = emptyTemplateE ts :& emptyTpl t

-- Generate an array shape according to the given dimension constraints. The
-- state maps dimension names to the size already chosen for them, so that a
-- name that occurs again is given the same size by genNamedDim.
--
-- After generation, every dimension is divided by @size `div` 100 + 1@ to
-- keep the arrays small.
--
-- NOTE(review): this division happens after named sizes were recorded and
-- without looking at the template, so a scaled-down shape can disagree with
-- another array that shares a dimension name, and a dimension can drop below
-- its requested minimum. Confirm that this is acceptable for the templates in
-- use.
genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
genShape = \n tpl -> do
  sh <- genShapeNaive n tpl
  let sz = shapeSize sh
      factor = sz `div` 100 + 1
  return (shapeDiv sh factor)
  where
    -- Generate every dimension independently (up to shared names).
    genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
    genShapeNaive SZ () = return ShNil
    genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name
    genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name

    -- Anonymous dimensions are always fresh; named dimensions are looked up
    -- in the state first and recorded there when newly generated.
    genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int
    genNamedDim (C "" lo) = genDim lo
    genNamedDim (C name lo) = gets (Map.lookup name) >>= \case
      Nothing -> do
        dim <- genDim lo
        modify (Map.insert name dim)
        return dim
      Just dim -> return dim

    -- A dimension size drawn from @Range.linear lo 10@.
    genDim :: Int -> StateT (Map String Int) Gen Int
    genDim lo = Gen.integral (Range.linear lo 10)

    -- Divide every dimension of the shape by the given factor.
    shapeDiv :: Shape n -> Int -> Shape n
    shapeDiv ShNil _ = ShNil
    shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f)

-- Generate an array of the given shape. Every element is generated with the
-- unconstrained template and a fresh, empty name map.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray t sh =
  Value <$> arrayGenerateLinM sh (\_ ->
              unValue <$> evalStateT (genValue t (emptyTpl t)) mempty)

-- Generate a value of the given type according to the template. Only arrays
-- and pairs look at the template; the components of Either and Maybe are
-- generated with the unconstrained template.
genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
genValue topty tpl = case topty of
  STNil -> return (Value ())
  STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl)
  STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a)
                             ,liftV Right <$> genValue b (emptyTpl b)]
  STMaybe t -> Gen.choice [return (Value Nothing)
                          ,liftV Just <$> genValue t (emptyTpl t)]
  STArr n t -> genShape n tpl >>= lift . genArray t
  STScal sty -> case sty of
    STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STBool -> Gen.choice [return (Value False), return (Value True)]
  STAccum{} -> error "Cannot generate inputs for accumulators"

-- Generate a value for every entry of the environment. The head of the
-- environment is generated first, so it is the first to record named
-- dimensions in the shared state.
genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env)
genEnv SNil () = return SNil
genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl

-- Test a term on unconstrained random inputs.
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest = adTestCon (const True)

-- Test a term on unconstrained random inputs that satisfy the predicate;
-- other inputs are discarded by 'Gen.filter'.
adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property
adTestCon constr term =
  let env = knownEnv
  in adTestGen term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty))

-- Test a term on random inputs generated according to the given template.
adTestTp :: forall env. KnownEnv env => TemplateE env -> Ex env (TScal TF64) -> Property
adTestTp tmpl term = adTestGen term (evalStateT (genEnv knownEnv tmpl) mempty)

-- The actual property. For an input drawn from the given generator, it checks:
--
-- * simplifying to a fixpoint and simplifying with @SimplIters 20@ yield the
--   same pretty-printed term;
-- * the primal output of the simplified CHAD term is close to that of the
--   unsimplified CHAD term and to the directly interpreted expression;
-- * the gradient of the simplified CHAD term is close to that of the
--   unsimplified CHAD term and to the forward-mode gradient.
adTestGen :: forall env. KnownEnv env => Ex env (TScal TF64) -> Gen (SList Value env) -> Property
adTestGen expr envGenerator = property $ do
  let env = knownEnv @env
  input <- forAllWith (showEnv env) envGenerator

  let outPrimal = interpretOpen False input expr
      gradFwd = gradientByForward knownEnv expr input
      (_ppdterm, (outChad, gradCHAD)) = gradientByCHAD' (SimplIters 0) knownEnv expr input
      (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' SimplFix knownEnv expr input
      (ppdterm_S20, _) = gradientByCHAD' (SimplIters 20) knownEnv expr input

      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
      scCHAD_S = envScalars env gradCHAD_S

  annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr))
  -- annotate (ppExpr knownEnv expr)
  -- annotate ppdterm
  -- annotate ppdterm_S
  diff ppdterm_S (==) ppdterm_S20
  diff outChad_S closeIsh outChad
  diff outChad_S closeIsh outPrimal
  diff scCHAD_S allCloseIsh scCHAD
  diff scCHAD_S allCloseIsh scFwd
  where
    -- All scalars of all tangents in the environment, in order.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs

    -- Element-wise 'closeIsh'. The lengths are compared explicitly because
    -- zipWith stops at the end of the shorter list, which would let a
    -- gradient with missing or surplus components pass unnoticed.
    allCloseIsh :: [Double] -> [Double] -> Bool
    allCloseIsh xs ys = length xs == length ys && and (zipWith closeIsh xs ys)

-- Rebuilds a one-dimensional array element by element and sums it.
term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_build1_sum = fromNamed $ lambda #x $ body $
  idx0 $ sum1i $
    build (SS SZ) (shape #x) $ #idx :-> #x ! #idx

-- Scalar arithmetic through nested pairs.
term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
  let_ #p (pair #x #y) $
  let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
    fst_ #q * #x + snd_ #q * fst_ #p

-- The Hedgehog test group; the result is that of 'checkParallel'.
tests :: IO Bool
tests = checkParallel $ Group "AD"
  [("id", adTest $ fromNamed $ lambda #x $ body $ #x)

  ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x)

  ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x))

  ,("sum-replicate", adTest $ fromNamed $ lambda #x $ body $
      idx0 $ sum1i $ replicate1i 10 #x)

  ,("pairs", adTest term_pairs)

  ,("build0 const", adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $
      idx0 $ build SZ nil $ #idx :-> const_ 0.0)

  ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
      idx0 $ build SZ (shape #x) $ #idx :-> #x ! #idx)

  ,("build1-sum", adTest term_build1_sum)

  ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
      idx0 $ sum1i . sum1i $
        build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx)

  -- Inputs whose last dimension is empty are filtered out.
  ,("maximum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
        idx0 $ sum1i $ maximum1i #x)

  -- Inputs whose last dimension is empty are filtered out.
  ,("minimum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
        idx0 $ sum1i $ minimum1i #x)

  ,("unused", adTest $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
      let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
        42)

  ,("neural", adTestGen Example.neural $ do
      let tR = STScal STF64
      -- A layer is a pair of an (nout x nin) array and a length-nout array.
      let genLayer nin nout =
            liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                       <*> genArray tR (ShNil `ShCons` nout)
      nin <- Gen.integral (Range.linear 1 10)
      n1 <- Gen.integral (Range.linear 1 10)
      n2 <- Gen.integral (Range.linear 1 10)
      input <- genArray tR (ShNil `ShCons` nin)
      lay1 <- genLayer nin n1
      lay2 <- genLayer n1 n2
      lay3 <- genArray tR (ShNil `ShCons` n2)
      return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))

  -- The template requests at least one element.
  ,("logsumexp", adTestTp (C "" 1) $
      fromNamed $ lambda @(TArr N1 _) #vec $ body $
        let_ #m (maximum1i #vec) $
          log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m)

  -- The template gives the same name "n" to the last dimension of the matrix
  -- and to the length of the vector.
  ,("mulmatvec", adTestTp ((C "" 0 :$ C "n" 0) :& C "n" 0) $
      fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
        idx0 $ sum1i $
          let_ #hei (snd_ (fst_ (shape #mat))) $
          let_ #wid (snd_ (shape #mat)) $
            build1 #hei $ #i :->
              idx0 (sum1i (build1 #wid $ #j :->
                             #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)))

  ,("gmm-wrong", withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM)

  ,("gmm", withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM)
  ]
  where
    -- Input generator shared by the two GMM tests.
    genGMM = do
      -- The input ranges here are completely arbitrary.
      let tR = STScal STF64
      kN <- Gen.integral (Range.linear 1 10)
      kD <- Gen.integral (Range.linear 1 10)
      kK <- Gen.integral (Range.linear 1 10)
      let i2i64 = fromIntegral @Int @Int64
      valpha <- genArray tR (ShNil `ShCons` kK)
      vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
      vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
      vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
      vm <- Gen.integral (Range.linear 0 5)
      let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
          k2 = 0.5 * vgamma * vgamma
          k3 = 0.42  -- don't feel like multigammaing today
      return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
              Value vm `SCons` vX `SCons`
              vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
              Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
              SNil)

main :: IO ()
main = defaultMain [tests]