{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Main where

import Data.Bifunctor
import Data.List (intercalate)

import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main

import Array
import AST
import AST.Pretty
import CHAD.Top
import CHAD.Types
import qualified Example
import ForwardAD
import Interpreter
import Interpreter.Rep
import Language
import Simplify


-- | Selects how the CHAD-differentiated term is simplified before it is
-- interpreted.
data SimplIters
  = SimplIters Int  -- ^ simplify with 'simplifyN', using this iteration count
  | SimplFix        -- ^ simplify with 'simplifyFix'
  deriving (Show)


-- | Differentiate a scalar-valued term with CHAD and evaluate the derivative.
--
-- The CHAD-transformed term is wrapped in a let-binding of the constant 1.0
-- (presumably the seed cotangent for the scalar output — confirm against
-- 'chad''). It is then simplified according to the 'SimplIters' setting and
-- interpreted on the given input environment.
--
-- Returns the pretty-printed (simplified) differentiated term, together with
-- the primal output and the per-input cotangents unpacked into an 'SList'.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
  (ppExpr env simplified, (primalOut, unTup vUnpair (d2e env) (Value cotangents)))
  where
    unsimplified = ELet ext (EConst ext STF64 1.0) $ chad' env term

    -- The simplifier needs the environment to be known; 'envKnown' supplies
    -- that evidence as a 'Dict', so the simplifier call must sit inside the
    -- match on it.
    simplified = case envKnown env of
      Dict -> case simplIters of
        SimplIters n -> simplifyN n unsimplified
        SimplFix -> simplifyFix unsimplified

    (primalOut, cotangents) = interpretOpen False input simplified

-- In addition to the gradient, also returns the pretty-printed differentiated term.
-- | Like 'gradientByCHAD', but converts the CHAD cotangents of the inputs into
-- tangents ('TanE'), so they can be compared directly with the result of
-- 'gradientByForward'. Also returns the pretty-printed differentiated term.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' = \simplIters env term input ->
  second (second (toTanE env input)) $ gradientByCHAD simplIters env term input
  where
    -- Convert each environment entry's cotangent into a tangent. The primal
    -- value of the entry fills in structure where the cotangent is zero.
    toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env primal inp

    -- Convert one cotangent into a tangent shaped like the primal value.
    -- Zero cotangents ('Left ()' for pairs/Eithers, an empty array for arrays,
    -- 'Nothing' for Maybe) become 'zeroTan' of the primal, which matches the
    -- structure that forward AD produces.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right (d1, d2) -> bimap (\p1 -> toTan t1 p1 d1) (\p2 -> toTan t2 p2 d2) primal
      STEither t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right d -> case (primal, d) of
          (Left p, Left d') -> Left (toTan t1 p d')
          (Right p, Right d') -> Right (toTan t2 p d')
          _ -> error "Primal and cotangent disagree on Either alternative"
      -- The tangent must follow the primal's structure. A 'Nothing' cotangent
      -- for a 'Just' primal is treated as zero, so it yields 'Just (zeroTan ..)'
      -- rather than 'Nothing'. (Previously 'liftA2' returned 'Nothing' here,
      -- which dropped this entry's scalars and misaligned the comparison with
      -- forward AD.)
      STMaybe t -> case (primal, der) of
        (Nothing, Nothing) -> Nothing
        (Just p, Nothing) -> Just (zeroTan t p)
        (Just p, Just d) -> Just (toTan t p d)
        (Nothing, Just _) -> error "Primal and cotangent disagree on Maybe alternative"
      STArr _ t
        | shapeSize (arrayShape der) == 0 -> arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise -> error "Primal and cotangent disagree on array shape"
      STScal sty -> case sty of
        STI32 -> der
        STI64 -> der
        STF32 -> der
        STF64 -> der
        STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Gradient of a scalar-valued term computed with forward-mode AD, seeded
-- with 1.0 via 'drevByFwd'.
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0

-- | Approximate equality of two doubles. True if the absolute difference is
-- below 1e-5, or if both magnitudes exceed 1e-4 and the difference relative
-- to the smaller magnitude is below 1e-5.
closeIsh :: Double -> Double -> Bool
closeIsh a b =
  abs (a - b) < 1e-5
    || (let scale = min (abs a) (abs b)
        in scale > 1e-4 && abs (a - b) / scale < 1e-5)

-- | Generate a shape of rank @n@. Each dimension is first drawn from [0, 10].
-- Every dimension is then divided by @size `div` 100 + 1@ to keep the
-- generated arrays small.
genShape :: SNat n -> Gen (Shape n)
genShape = \n -> do
  sh <- genShapeNaive n
  let sz = shapeSize sh
      factor = sz `div` 100 + 1
  return (shapeDiv sh factor)
  where
    genShapeNaive :: SNat n -> Gen (Shape n)
    genShapeNaive SZ = return ShNil
    genShapeNaive (SS n) = ShCons <$> genShapeNaive n <*> Gen.integral (Range.linear 0 10)

    shapeDiv :: Shape n -> Int -> Shape n
    shapeDiv ShNil _ = ShNil
    shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f)

-- | Generate an array of the given shape, with each element generated
-- independently by 'genValue'.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> genValue t)

-- | Generate a random value of the given type. Numeric scalars lie in
-- [-10, 10]. Sum types and booleans pick an alternative with 'Gen.choice'.
-- Accumulators cannot be generated.
genValue :: STy a -> Gen (Value a)
genValue = \case
  STNil -> return (Value ())
  STPair a b -> liftV2 (,) <$> genValue a <*> genValue b
  STEither a b -> Gen.choice [liftV Left <$> genValue a
                             ,liftV Right <$> genValue b]
  STMaybe t -> Gen.choice [return (Value Nothing)
                          ,liftV Just <$> genValue t]
  STArr n t -> genShape n >>= genArray t
  STScal sty -> case sty of
    STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STBool -> Gen.choice [return (Value False), return (Value True)]
  STAccum{} -> error "Cannot generate inputs for accumulators"

-- | Generate a value for every entry of an environment.
genEnv :: SList STy env -> Gen (SList Value env)
genEnv SNil = return SNil
genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env

-- | Render an environment of values as a bracketed, comma-separated list.
showEnv :: SList STy env -> SList Value env -> String
showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
  where
    showEntries :: SList STy env -> SList Value env -> [String]
    showEntries SNil SNil = []
    showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs

-- | AD property test on unconstrained random inputs.
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest = adTestCon (const True)

-- | AD property test on random inputs that satisfy the given predicate.
adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property
adTestCon constr term = adTestGen term (Gen.filter constr (genEnv (knownEnv @env)))

-- | AD property test on inputs from a custom generator. For each input it
-- checks that:
--
-- * simplifying 20 times prints the same term as simplifying with 'SimplFix';
-- * unsimplified and simplified CHAD agree on the primal output, and both
--   agree with direct interpretation;
-- * unsimplified CHAD, simplified CHAD and forward AD produce the same
--   gradient scalars (same count, each pair 'closeIsh').
adTestGen :: forall env. KnownEnv env => Ex env (TScal TF64) -> Gen (SList Value env) -> Property
adTestGen expr envGenerator = property $ do
  let env = knownEnv @env
  input <- forAllWith (showEnv env) envGenerator
  let outPrimal = interpretOpen False input expr
      gradFwd = gradientByForward knownEnv expr input
      (ppdterm, (outChad, gradCHAD)) = gradientByCHAD' (SimplIters 0) knownEnv expr input
      (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' SimplFix knownEnv expr input
      (ppdterm_S20, _) = gradientByCHAD' (SimplIters 20) knownEnv expr input
      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
      scCHAD_S = envScalars env gradCHAD_S
  annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr))
  annotate (ppExpr knownEnv expr)
  annotate ppdterm
  annotate ppdterm_S
  diff ppdterm_S20 (==) ppdterm_S
  diff outChad closeIsh outChad_S
  diff outPrimal closeIsh outChad_S
  diff scCHAD closeIshAll scCHAD_S
  diff scFwd closeIshAll scCHAD_S
  where
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs

    -- Element-wise 'closeIsh' that also requires equal lengths. A bare
    -- 'zipWith' truncates to the shorter list, so a gradient with missing or
    -- extra scalars would otherwise pass unnoticed.
    closeIshAll :: [Double] -> [Double] -> Bool
    closeIshAll xs ys = length xs == length ys && and (zipWith closeIsh xs ys)

-- | Sum of a 1-D array, written as a 'build' that copies the input.
term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_build1_sum = fromNamed $ lambda #x $ body $
  idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx

-- | Scalar function that builds and projects out of nested pairs.
term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
  let_ #p (pair #x #y) $
  let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
  fst_ #q * #x + snd_ #q * fst_ #p

-- | All AD property tests, run sequentially.
tests :: IO Bool
tests = checkSequential $ Group "AD"
  [("id", adTest $ fromNamed $ lambda #x $ body $ #x)

  ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x)

  ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x))

  ,("sum-replicate", adTest $ fromNamed $ lambda #x $ body $ idx0 $ sum1i $ replicate1i 10 #x)

  ,("pairs", adTest term_pairs)

  ,("build0 const", adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $
      idx0 $ build SZ nil $ #idx :-> const_ 0.0)

  ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
      idx0 $ build SZ (shape #x) $ #idx :-> #x ! #idx)

  ,("build1-sum", adTest term_build1_sum)

  ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
      idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx)

  ,("maximum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ idx0 $ sum1i $ maximum1i #x)

  ,("minimum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ idx0 $ sum1i $ minimum1i #x)

  ,("neural", adTestGen Example.neural $ do
      let tR = STScal STF64
      let genLayer nin nout =
            liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                       <*> genArray tR (ShNil `ShCons` nout)
      nin <- Gen.integral (Range.linear 1 10)
      n1 <- Gen.integral (Range.linear 1 10)
      n2 <- Gen.integral (Range.linear 1 10)
      input <- genArray tR (ShNil `ShCons` nin)
      lay1 <- genLayer nin n1
      lay2 <- genLayer n1 n2
      lay3 <- genArray tR (ShNil `ShCons` n2)
      return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
  ]

main :: IO ()
main = defaultMain [tests]