{-# LANGUAGE DataKinds #-}
-- {-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- {-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Property tests checking that gradients computed by CHAD ('drev') agree
-- with gradients computed by forward-mode AD ('drevByFwd') on randomly
-- generated inputs.
module Main where

import Data.Bifunctor
-- import qualified Data.Dependent.Map as DMap
-- import Data.Dependent.Map (DMap)
import Data.Foldable (toList)
import Data.List (intercalate, intersperse)
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main

import Array
import AST
import CHAD
import CHAD.Types
import Data
import qualified Example
import ForwardAD
import Interpreter
import Interpreter.Rep
import Language


-- | Tags every entry of a type-level environment with @\"merge\"@.
type family MapMerge env where
  MapMerge '[] = '[]
  MapMerge (t : ts) = "merge" : MapMerge ts

-- | Proof that selecting the @\"accum\"@ entries of a 'MapMerge'-tagged
-- environment yields the empty list.
mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
mapMergeNoAccum SNil = Refl
mapMergeNoAccum (_ `SCons` env) | Refl <- mapMergeNoAccum env = Refl

-- | Proof that selecting the @\"merge\"@ entries of a 'MapMerge'-tagged
-- environment yields the whole environment.
mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge SNil = Refl
mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl

-- | Gradient of a scalar ('TF64') term with respect to all of its free
-- variables, computed with CHAD.
--
-- Every free variable is tagged 'SMerge' (none are accumulators). The
-- output cotangent is seeded with @1.0@, and the reverse-mode term is
-- evaluated with 'interpretOpen' on the primal-converted inputs.
gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env)
gradientByCHAD = \env term input ->
  case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
    (Refl, Refl) ->
      let descr = makeMergeDescr env
          dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
          input1 = toPrimalE env input
          (_out, grad) = interpretOpen input1 dterm
      in unTup vUnpair (d2e env) (Value grad)
  where
    -- Descriptor that marks every environment entry as 'SMerge'.
    makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
    makeMergeDescr SNil = DTop
    makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge)

    -- Convert each input value to its D1 (primal) representation.
    toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
    toPrimalE SNil SNil = SNil
    toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp

    toPrimal :: STy t -> Rep t -> Rep (D1 t)
    toPrimal = \case
      STNil -> id
      STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STMaybe t -> fmap (toPrimal t)
      STArr _ t -> fmap (toPrimal t)
      STScal _ -> id
      STAccum{} -> error "Accumulators not allowed in input program"

-- | The CHAD gradient, converted from the cotangent representation ('D2E')
-- to the tangent representation ('TanE') so that it can be compared with the
-- forward-mode result. The primal inputs supply structure wherever the
-- cotangent is a compact zero.
gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)
  where
    toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env primal inp

    -- Convert one cotangent to a tangent shaped like the primal value.
    -- 'Left ()' (pair/Either) and 'Nothing' for a 'Just' primal are treated
    -- as zero cotangents and expanded with 'zeroTan'.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right (d1, d2) -> bimap (\p1 -> toTan t1 p1 d1) (\p2 -> toTan t2 p2 d2) primal
      STEither t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right d -> case (primal, d) of
          (Left p, Left d') -> Left (toTan t1 p d')
          (Right p, Right d') -> Right (toTan t2 p d')
          _ -> error "Primal and cotangent disagree on Either alternative"
      STMaybe t -> case (primal, der) of
        (Just p, Just d) -> Just (toTan t p d)
        -- A missing cotangent for a present primal is zero; the tangent must
        -- still carry one (zero) entry per scalar of the primal, otherwise it
        -- no longer lines up with the forward-mode tangent.
        (Just p, Nothing) -> Just (zeroTan t p)
        (Nothing, _) -> Nothing
      STArr _ t
        | shapeSize (arrayShape der) == 0 -> arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise -> error "Primal and cotangent disagree on array shape"
      STScal sty -> case sty of
        STI32 -> der
        STI64 -> der
        STF32 -> der
        STF64 -> der
        STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Reference gradient computed by forward-mode AD ('drevByFwd', passing
-- @1.0@).
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0

-- | Approximate equality. Holds when the absolute difference is below 1e-5,
-- or when both magnitudes exceed 1e-4 and the relative difference (with
-- respect to the smaller magnitude) is below 1e-5.
closeIsh :: Double -> Double -> Bool
closeIsh a b =
  abs (a - b) < 1e-5
    || (let scale = min (abs a) (abs b)
        in scale > 1e-4 && abs (a - b) / scale < 1e-5)

-- | Random shape of rank @n@. Each dimension is drawn from [0, 10]. If the
-- total size reaches 100 or more, every dimension is divided by
-- @size \`div\` 100 + 1@ to keep arrays small.
genShape :: SNat n -> Gen (Shape n)
genShape = \n -> do
  sh <- genShapeNaive n
  let sz = shapeSize sh
      factor = sz `div` 100 + 1
  return (shapeDiv sh factor)
  where
    genShapeNaive :: SNat n -> Gen (Shape n)
    genShapeNaive SZ = return ShNil
    genShapeNaive (SS n) = ShCons <$> genShapeNaive n <*> Gen.integral (Range.linear 0 10)

    shapeDiv :: Shape n -> Int -> Shape n
    shapeDiv ShNil _ = ShNil
    shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f)

-- | Random array of the given shape with independently generated elements.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> genValue t)

-- | Random value of the given type. Numeric scalars come from [-10, 10] and
-- shrink towards 0. Accumulators cannot be generated.
genValue :: STy a -> Gen (Value a)
genValue = \case
  STNil -> return (Value ())
  STPair a b -> liftV2 (,) <$> genValue a <*> genValue b
  STEither a b -> Gen.choice [liftV Left <$> genValue a
                             ,liftV Right <$> genValue b]
  STMaybe t -> Gen.choice [return (Value Nothing)
                          ,liftV Just <$> genValue t]
  STArr n t -> genShape n >>= genArray t
  STScal sty -> case sty of
    STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STBool -> Gen.choice [return (Value False), return (Value True)]
  STAccum{} -> error "Cannot generate inputs for accumulators"

-- | Random value for every entry of an environment.
genEnv :: SList STy env -> Gen (SList Value env)
genEnv SNil = return SNil
genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env

-- data TemplateVar n = TemplateVar (SNat n) String
--   deriving (Show)

-- data Template t where
--   TpShape :: TemplateVar n -> STy t -> Template (TArr n t)
--   TpAny :: STy t -> Template t
--   TpPair :: Template a -> Template b -> Template (TPair a b)
-- deriving instance Show (Template t)

-- data ShapeConstraint n = ShapeAtLeast (Shape n)
--   deriving (Show)

-- genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t)
-- genTemplate = _

-- genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env)
-- genEnvTemplateExact shapes env = _

-- genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env)
-- genEnvTemplate constrs env = do
--   shapes <- DMap.traverseWithKey _ constrs
--   genEnvTemplateExact shapes env

-- | Render a value as Haskell-like source text. The 'Int' argument is the
-- surrounding precedence, following 'showsPrec' conventions.
showValue :: Int -> STy t -> Rep t -> ShowS
showValue _ STNil () = showString "()"
showValue _ (STPair a b) (x, y) =
  showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x
showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y
showValue _ (STMaybe _) Nothing = showString "Nothing"
showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
showValue d (STArr _ t) arr =
  showParen (d > 10) $
    showString "arrayFromList "
      . showsPrec 11 (arrayShape arr)
      . showString " ["
      . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))
      . showString "]"
showValue _ (STScal sty) x = case sty of
  STF32 -> shows x
  STF64 -> shows x
  STI32 -> shows x
  STI64 -> shows x
  STBool -> shows x
showValue _ STAccum{} _ = error "Cannot show accumulators"

-- | Render an environment of values as a bracketed, comma-separated list.
showEnv :: SList STy env -> SList Value env -> String
showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
  where
    showEntries :: SList STy env -> SList Value env -> [String]
    showEntries SNil SNil = []
    showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs

-- | 'adTestGen' with inputs drawn from 'genEnv' for the known environment.
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest = flip adTestGen (genEnv (knownEnv @env))

-- adTestTp :: forall env. KnownEnv env
--          => DMap TemplateVar ShapeConstraint -> SList Template env
--          -> Ex env (TScal TF64) -> Property
-- adTestTp envConstrs envTp = adTestGen (genEnvTemplate envConstrs envTp)

-- | Property: on generated inputs, the CHAD gradient and the forward-mode
-- gradient flatten to the same number of scalars, and corresponding scalars
-- agree according to 'closeIsh'.
adTestGen :: forall env. KnownEnv env
          => Ex env (TScal TF64) -> Gen (SList Value env) -> Property
adTestGen expr envGenerator = property $ do
  let env = knownEnv @env
  input <- forAllWith (showEnv env) envGenerator
  let gradFwd = gradientByForward knownEnv expr input
      gradCHAD = gradientByCHAD' knownEnv expr input
      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
  -- 'zipWith' truncates to the shorter list, so compare the lengths
  -- explicitly; otherwise missing or extra scalars could go unnoticed.
  diff scCHAD (\x y -> length x == length y && and (zipWith closeIsh x y)) scFwd
  where
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs

-- | The AD test group.
tests :: IO Bool
tests = checkParallel $ Group "AD"
  [("id", adTest $ fromNamed $ lambda #x $ body $
      #x)

  ,("idx0", adTest $ fromNamed $ lambda #x $ body $
      idx0 #x)

  ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $
      idx0 (sum1i #x))

  ,("sum-replicate", adTest $ fromNamed $ lambda #x $ body $
      idx0 $ sum1i $ replicate1i 10 #x)

  ,("pairs", adTest $ fromNamed $ lambda #x $ lambda #y $ body $
      let_ #p (pair #x #y) $
      let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
        fst_ #q * #x + snd_ #q * fst_ #p)

  ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
      idx0 $
      build SZ (shape #x) $ #idx :-> #x ! #idx)

  ,("build1-sum", adTest $ fromNamed $ lambda @(TArr N1 _) #x $ body $
      idx0 $ sum1i $
      build (SS SZ) (shape #x) $ #idx :-> #x ! #idx)

  ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
      idx0 $ sum1i . sum1i $
      build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx)

  -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $
  --     idx0 $ sum1i . sum1i $
  --     build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :->
  --       oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx)

  -- ,("neural", adTestGen Example.neural $ do
  --     let tR = STScal STF64
  --     let genLayer nin nout =
  --           liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
  --                      <*> genArray tR (ShNil `ShCons` nout)
  --     nin <- Gen.integral (Range.linear 1 10)
  --     n1 <- Gen.integral (Range.linear 1 10)
  --     n2 <- Gen.integral (Range.linear 1 10)
  --     input <- genArray tR (ShNil `ShCons` nin)
  --     lay1 <- genLayer nin n1
  --     lay2 <- genLayer n1 n2
  --     lay3 <- genArray tR (ShNil `ShCons` n2)
  --     return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
  ]

-- | Run all test groups.
main :: IO ()
main = defaultMain [tests]