{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Property tests checking that reverse-mode CHAD gradients agree with
-- gradients obtained through 'drevByFwd'.
module Main where

import Data.Bifunctor
import Hedgehog
import Hedgehog.Main

import Array
import AST
import CHAD
import CHAD.Types
import Data
import ForwardAD
import Interpreter
import Interpreter.Rep


-- | Assigns the @"merge"@ storage mode to every entry of an environment.
-- Used to build a 'drev' descriptor in which every variable uses 'SMerge'.
type family MapMerge env where
  MapMerge '[] = '[]
  MapMerge (t : ts) = "merge" : MapMerge ts

-- | Proof that an all-merge environment selects no @"accum"@ variables.
mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
mapMergeNoAccum SNil = Refl
mapMergeNoAccum (_ `SCons` env)
  | Refl <- mapMergeNoAccum env
  = Refl

-- | Proof that selecting the @"merge"@ variables of an all-merge environment
-- yields the whole environment.
mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge SNil = Refl
mapMergeOnlyMerge (_ `SCons` env)
  | Refl <- mapMergeOnlyMerge env
  = Refl

-- | Gradient of a scalar-valued expression with respect to all of its free
-- variables, computed by reverse-mode CHAD ('drev').
--
-- Every free variable is put in merge mode, the output cotangent is seeded
-- with @1.0@, and the differentiated program is run with 'interpretOpen' on
-- the primal ('D1') versions of the inputs. The result is one cotangent per
-- variable, as an 'SList' indexed by @'D2E' env@.
--
-- Calls 'error' if any input type contains an accumulator.
gradientByCHAD
  :: forall env.
     SList STy env
  -> Ex env (TScal TF64)
  -> SList Value env
  -> SList Value (D2E env)
gradientByCHAD = \env term input ->
  case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
    (Refl, Refl) ->
      let descr = makeMergeDescr env
          dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
          input1 = toPrimalE env input
          -- The first component is not needed; only the cotangent tuple is.
          (_out, grad) = interpretOpen input1 dterm
      in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad)
  where
    -- Descriptor that stores every environment variable in merge mode.
    makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
    makeMergeDescr SNil = DTop
    makeMergeDescr (t `SCons` rest) = makeMergeDescr rest `DPush` (t, SMerge)

    -- Converts every input value to its primal ('D1') representation.
    toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
    toPrimalE SNil SNil = SNil
    toPrimalE (t `SCons` rest) (Value x `SCons` inp) =
      Value (toPrimal t x) `SCons` toPrimalE rest inp

    -- Converts a single value to its primal representation: unit and scalars
    -- are unchanged, containers are converted element-wise.
    toPrimal :: STy t -> Rep t -> Rep (D1 t)
    toPrimal = \case
      STNil -> id
      STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STMaybe t -> fmap (toPrimal t)
      STArr _ t -> fmap (toPrimal t)
      STScal _ -> id
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Like 'gradientByCHAD', but converts each cotangent into the tangent
-- representation ('Tan') so that it can be compared with 'gradientByForward'.
--
-- Cotangents that the code treats as zero are expanded into explicit zero
-- tangents ('zeroTan') shaped like the corresponding primal input:
--
-- * @Left ()@ for pairs and Eithers;
-- * 'Nothing' for a 'Just' primal of Maybe type;
-- * an array cotangent with zero elements.
--
-- Calls 'error' when a primal and its cotangent disagree on an Either or
-- Maybe alternative or on an array shape, or when an accumulator occurs.
gradientByCHAD'
  :: SList STy env
  -> Ex env (TScal TF64)
  -> SList Value env
  -> SList Value (TanE env)
gradientByCHAD' = \env term input ->
  toTanE env input (gradientByCHAD env term input)
  where
    -- Pairs up every primal input with its cotangent.
    toTanE
      :: SList STy env'
      -> SList Value env'
      -> SList Value (D2E env')
      -> SList Value (TanE env')
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` rest) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE rest primal inp

    -- Converts one cotangent to a tangent, using the primal value to give
    -- shape to zero cotangents.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right (d1, d2) ->
          bimap (\p1 -> toTan t1 p1 d1) (\p2 -> toTan t2 p2 d2) primal
      STEither t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right d -> case (primal, d) of
          (Left p, Left d') -> Left (toTan t1 p d')
          (Right p, Right d') -> Right (toTan t2 p d')
          _ -> error "Primal and cotangent disagree on Either alternative"
      STMaybe t -> case (primal, der) of
        (Nothing, Nothing) -> Nothing
        -- A 'Nothing' cotangent is zero; keep the tangent's shape aligned
        -- with the primal, like the Pair/Either/Arr cases do.
        (Just p, Nothing) -> Just (zeroTan t p)
        (Just p, Just d) -> Just (toTan t p d)
        (Nothing, Just _) ->
          error "Primal and cotangent disagree on Maybe alternative"
      STArr _ t
        -- An empty cotangent array is treated as zero.
        | shapeSize (arrayShape der) == 0 ->
            arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise ->
            error "Primal and cotangent disagree on array shape"
      -- Matching each scalar type individually lets GHC reduce the
      -- type families for that concrete scalar.
      STScal sty -> case sty of
        STI32 -> der
        STI64 -> der
        STF32 -> der
        STF64 -> der
        STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Reference gradient obtained from 'drevByFwd', with the output seeded
-- with @1.0@.
gradientByForward
  :: SList STy env
  -> Ex env (TScal TF64)
  -> SList Value env
  -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0

-- | Approximate equality for gradient entries. Two numbers are close when:
--
-- * their absolute difference is below @1e-5@; or
-- * both magnitudes exceed @1e-4@ and their difference, relative to the
--   smaller magnitude, is below @1e-5@.
closeIsh :: Double -> Double -> Bool
closeIsh a b =
  abs (a - b) < 1e-5
    || (let scale = min (abs a) (abs b)
        in scale > 1e-4 && abs (a - b) / scale < 1e-5)

-- | Property: on the given input, the CHAD gradient matches the reference
-- gradient from 'gradientByForward', scalar by scalar, up to 'closeIsh'.
--
-- Both gradients are flattened to lists of scalars with 'tanScalars'. The
-- lists must also have equal length; 'zipWith' alone would silently ignore
-- extra scalars on one side.
adTest
  :: forall env. KnownEnv env
  => SList Value env
  -> Ex env (TScal TF64)
  -> Property
adTest input expr = property $
  let env = knownEnv @env
      gradFwd = gradientByForward env expr input
      gradCHAD = gradientByCHAD' env expr input
      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
  in diff scCHAD closeAll scFwd
  where
    closeAll :: [Double] -> [Double] -> Bool
    closeAll xs ys = length xs == length ys && and (zipWith closeIsh xs ys)

    -- Concatenates the scalars of every tangent in the environment.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) =
      tanScalars t x ++ envScalars ts xs

-- | All AD property tests.
tests :: IO Bool
tests = checkParallel $ Group "AD"
  [ ("id", adTest (Value 42.0 `SCons` SNil) (EVar ext (STScal STF64) IZ))
  ]

main :: IO ()
main = defaultMain [tests]