{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Main where

import Data.Bifunctor

import Array
import AST
import CHAD
import CHAD.Types
import Data
import ForwardAD
import Interpreter
import Interpreter.Rep


-- | Replace every entry of a type-level environment by the symbol @"merge"@.
-- Used by 'gradientByCHAD' to build a descriptor in which every variable is
-- tagged 'SMerge'.
type family MapMerge env where
  MapMerge '[] = '[]
  MapMerge (t : ts) = "merge" : MapMerge ts

-- | Proof that selecting the @"accum"@ entries of a 'MapMerge' environment
-- yields the empty list.
mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
mapMergeNoAccum SNil = Refl
mapMergeNoAccum (_ `SCons` env) | Refl <- mapMergeNoAccum env = Refl

-- | Proof that selecting the @"merge"@ entries of a 'MapMerge' environment
-- yields the whole environment back.
mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge SNil = Refl
mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl

-- | Gradient of a scalar ('TF64') term with respect to every free variable,
-- obtained by differentiating the term with 'drev' and running the result in
-- the interpreter.
--
-- Every environment entry is tagged 'SMerge', so the result contains one
-- 'D2'-typed value per input variable. Inputs of accumulator type are not
-- supported ('toPrimal' calls 'error' on them).
gradientByCHAD
  :: forall env.
     SList STy env
  -> Ex env (TScal TF64)
  -> SList Value env
  -> SList Value (D2E env)
gradientByCHAD = \env term input ->
  case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
    (Refl, Refl) ->
      let descr = makeMergeDescr env
          -- NOTE(review): the constant 1.0 is presumably the seed cotangent
          -- of the scalar output — confirm against 'freezeRet'.
          dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
          input1 = toPrimalE env input
          (_out, grad) = interpretOpen input1 dterm
      in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad)
  where
    -- Descriptor that tags every variable of the environment with 'SMerge'.
    makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
    makeMergeDescr SNil = DTop
    makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge)

    -- Convert every input value to its 'D1' (primal) representation.
    toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
    toPrimalE SNil SNil = SNil
    toPrimalE (t `SCons` env) (Value x `SCons` inp) =
      Value (toPrimal t x) `SCons` toPrimalE env inp

    -- Convert a single value to its 'D1' representation, structurally.
    toPrimal :: STy t -> Rep t -> Rep (D1 t)
    toPrimal = \case
      STNil -> id
      STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STMaybe t -> fmap (toPrimal t)
      STArr _ t -> fmap (toPrimal t)
      STScal _ -> id
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Like 'gradientByCHAD', but each cotangent is converted into the tangent
-- representation ('TanE'), which is the result type of 'gradientByForward'.
gradientByCHAD'
  :: SList STy env
  -> Ex env (TScal TF64)
  -> SList Value env
  -> SList Value (TanE env)
gradientByCHAD' = \env term input ->
  toTanE env input (gradientByCHAD env term input)
  where
    -- Pair up every primal input with its cotangent and convert it.
    toTanE
      :: SList STy env
      -> SList Value env
      -> SList Value (D2E env)
      -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env primal inp

    -- Convert a cotangent ('D2') into a tangent ('Tan') shaped like the
    -- primal. Sparse zero cotangents ('Left ()' for pairs/Either, 'Nothing'
    -- for Maybe, an empty array for arrays) become 'zeroTan' of the primal.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right (d1, d2) ->
          bimap (\p1 -> toTan t1 p1 d1) (\p2 -> toTan t2 p2 d2) primal
      STEither t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right d -> case (primal, d) of
          (Left p, Left d') -> Left (toTan t1 p d')
          (Right p, Right d') -> Right (toTan t2 p d')
          _ -> error "Primal and cotangent disagree on Either alternative"
      -- A 'Nothing' cotangent against a 'Just' primal is the zero
      -- cotangent; the tangent must still be 'Just' so that it has the
      -- primal's shape ('liftA2' would wrongly produce 'Nothing' here).
      -- NOTE(review): assumes 'Nothing' is the sparse zero of
      -- D2 (TMaybe t), as 'Left ()' is for pairs — confirm in CHAD.Types.
      STMaybe t -> case (primal, der) of
        (Just p, Just d) -> Just (toTan t p d)
        (Just p, Nothing) -> Just (zeroTan t p)
        (Nothing, _) -> Nothing
      STArr _ t
        | shapeSize (arrayShape der) == 0 -> arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise -> error "Primal and cotangent disagree on array shape"
      -- Matching on each scalar type lets the type families reduce so that
      -- 'der' can be returned unchanged.
      STScal sty -> case sty of
        STI32 -> der
        STI64 -> der
        STF32 -> der
        STF64 -> der
        STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Gradient of a scalar ('TF64') term computed with forward-mode AD
-- ('drevByFwd'). The final argument 1.0 is presumably the output seed —
-- NOTE(review): confirm against 'drevByFwd'. Same result type as
-- 'gradientByCHAD''.
gradientByForward
  :: SList STy env
  -> Ex env (TScal TF64)
  -> SList Value env
  -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0

-- | Entry point; currently does nothing.
main :: IO ()
main = return ()