{-# LANGUAGE DataKinds #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Reverse derivatives of expressions computed via forward-mode
-- (dual-number) differentiation, plus helpers for working with tangents.
module ForwardAD where

import Data.Bifunctor (bimap)

-- import Debug.Trace
-- import AST.Pretty

import Array
import AST
import Data
import ForwardAD.DualNumbers
import Interpreter
import Interpreter.Rep


-- | Tangent along a type (coincides with cotangent for these types).
type family Tan t where
  Tan TNil = TNil
  Tan (TPair a b) = TPair (Tan a) (Tan b)
  Tan (TEither a b) = TEither (Tan a) (Tan b)
  Tan (TMaybe t) = TMaybe (Tan t)
  Tan (TArr n t) = TArr n (Tan t)
  Tan (TScal t) = TanS t

-- | Tangent of a scalar type. The floating-point types are their own tangent.
-- Integer and boolean scalars have the trivial tangent 'TNil'.
type family TanS t where
  TanS TI32 = TNil
  TanS TI64 = TNil
  TanS TF32 = TScal TF32
  TanS TF64 = TScal TF64
  TanS TBool = TNil

-- | 'Tan' applied to every type of a type-level environment.
type family TanE env where
  TanE '[] = '[]
  TanE (t : env) = Tan t : TanE env

-- | Value-level counterpart of 'Tan'. It maps a type singleton to the
-- singleton of its tangent type.
tanty :: STy t -> STy (Tan t)
tanty STNil = STNil
tanty (STPair a b) = STPair (tanty a) (tanty b)
tanty (STEither a b) = STEither (tanty a) (tanty b)
tanty (STMaybe t) = STMaybe (tanty t)
tanty (STArr n t) = STArr n (tanty t)
tanty (STScal t) = case t of
  STI32 -> STNil
  STI64 -> STNil
  STF32 -> STScal STF32
  STF64 -> STScal STF64
  STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"

-- | The zero tangent at a given primal value.
--
-- The result follows the primal's runtime structure (the chosen 'Either'
-- alternative, 'Maybe' presence, and array shape). Every floating-point scalar
-- is set to 0.
zeroTan :: STy t -> Rep t -> Rep (Tan t)
zeroTan STNil () = ()
zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)
zeroTan (STEither a _) (Left x) = Left (zeroTan a x)
zeroTan (STEither _ b) (Right y) = Right (zeroTan b y)
zeroTan (STMaybe _) Nothing = Nothing
zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)
zeroTan (STArr _ t) x = arrayMap (zeroTan t) x
zeroTan (STScal STI32) _ = ()
zeroTan (STScal STI64) _ = ()
zeroTan (STScal STF32) _ = 0.0
zeroTan (STScal STF64) _ = 0.0
zeroTan (STScal STBool) _ = ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"

-- | Flatten the floating-point components of a tangent into a list of
-- 'Double's. 'Float' components are widened.
--
-- Pair components are listed first-then-second. Array elements appear in the
-- order of the 'Foldable' instance of the array type.
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y
tanScalars (STEither a _) (Left x) = tanScalars a x
tanScalars (STEither _ b) (Right y) = tanScalars b y
tanScalars (STMaybe _) Nothing = []
tanScalars (STMaybe t) (Just x) = tanScalars t x
tanScalars (STArr _ t) x = foldMap (tanScalars t) x
tanScalars (STScal STI32) _ = []
tanScalars (STScal STI64) _ = []
tanScalars (STScal STF32) x = [realToFrac x]
tanScalars (STScal STF64) x = [x]
tanScalars (STScal STBool) _ = []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"

-- | Split a dual-number value into its primal part and its tangent part.
--
-- For floating-point scalars the dual number is already a (primal, tangent)
-- pair. Other scalars have no tangent and get @()@.
unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
unzipDN STNil _ = ((), ())
unzipDN (STPair a b) (d1, d2) =
  let (x, dx) = unzipDN a d1
      (y, dy) = unzipDN b d2
  in ((x, y), (dx, dy))
unzipDN (STEither a b) d = case d of
  Left d1 -> bimap Left Left (unzipDN a d1)
  Right d2 -> bimap Right Right (unzipDN b d2)
unzipDN (STMaybe t) d = case d of
  Nothing -> (Nothing, Nothing)
  Just d' -> bimap Just Just (unzipDN t d')
unzipDN (STArr _ t) d =
  let pairs = arrayMap (unzipDN t) d
  in (arrayMap fst pairs, arrayMap snd pairs)
unzipDN (STScal ty) d = case ty of
  STI32 -> (d, ())
  STI64 -> (d, ())
  STF32 -> d
  STF64 -> d
  STBool -> (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"

-- | Inner product of two tangents of the same type, computed as a 'Double'.
--
-- An empty array on either side counts as zero, whatever the other side's
-- shape is. Mismatched 'Either' or 'Maybe' alternatives, and mismatched
-- non-empty array shapes, are reported with 'error'.
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
dotprodTan (STPair a b) (x, y) (x', y') = dotprodTan a x x' + dotprodTan b y y'
dotprodTan (STEither a b) x y = case (x, y) of
  (Left x', Left y') -> dotprodTan a x' y'
  (Right x', Right y') -> dotprodTan b x' y'
  _ -> error "dotprodTan: incompatible Either alternatives"
dotprodTan (STMaybe t) x y = case (x, y) of
  (Nothing, Nothing) -> 0.0
  (Just x', Just y') -> dotprodTan t x' y'
  _ -> error "dotprodTan: incompatible Maybe alternatives"
dotprodTan (STArr _ t) x y =
  let sh1 = arrayShape x
      sh2 = arrayShape y
  in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0
        | sh1 == sh2 ->
            sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1]
        | otherwise -> error "dotprodTan: incompatible array shapes"
dotprodTan (STScal ty) x y = case ty of
  STI32 -> 0.0
  STI64 -> 0.0
  -- Widen both operands before multiplying. This forms the product in double
  -- precision, like the F64 case, instead of rounding it to single precision
  -- first.
  STF32 -> realToFrac @Float @Double x * realToFrac @Float @Double y
  STF64 -> x * y
  STBool -> 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"

-- Expression-level counterpart of 'dnConst'; currently disabled.
--
-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
-- dnConstE STNil _ = ENil ext
-- dnConstE (STPair t1 t2) e =
--   -- This creates fst/snd stacks of unbounded size, but let's not care here
--   EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e))
-- dnConstE (STEither t1 t2) e =
--   ECase ext e
--     (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ)))
--     (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ)))
-- dnConstE (STMaybe t) e =
--   EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e
-- dnConstE (STArr n t) e =
--   EBuild ext n (EShape ext e)
--     (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ)))
-- dnConstE (STScal t) e = case t of
--   STI32 -> e
--   STI64 -> e
--   STF32 -> EPair ext e (EConst ext STF32 0.0)
--   STF64 -> EPair ext e (EConst ext STF64 0.0)
--   STBool -> e
-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program"

-- | Embed a primal value as a dual number whose tangent is zero.
--
-- Only floating-point scalars carry a tangent; a scalar @x@ becomes
-- @(x, 0.0)@. Every other scalar is passed through unchanged.
dnConst :: STy t -> Rep t -> Rep (DN t)
dnConst STNil = const ()
dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)
dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2)
dnConst (STMaybe t) = fmap (dnConst t)
dnConst (STArr _ t) = arrayMap (dnConst t)
dnConst (STScal t) = case t of
  STI32 -> id
  STI64 -> id
  STF32 -> (,0.0)
  STF64 -> (,0.0)
  STBool -> id
dnConst STAccum{} = error "Accumulators not allowed in input program"

-- | Given a function that computes the forward derivative for a particular
-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
-- @t@ input.
type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t)

-- | Gradient with respect to a primal value, obtained by calling the given
-- forward-derivative function once per floating-point scalar in that value.
--
-- Each call receives the dual-number version of the value with tangent 1.0 at
-- that one scalar. All other positions get the zero-tangent embedding
-- ('dnConst'). The returned 'Double' becomes the gradient entry at that
-- position.
dnOnehots :: STy t -> Rep t -> RevByFwd t
dnOnehots STNil _ = \_ -> ()
dnOnehots (STPair t1 t2) (x, y) = \f ->
  ( dnOnehots t1 x (f . (,dnConst t2 y))
  , dnOnehots t2 y (f . (dnConst t1 x,))
  )
dnOnehots (STEither t1 t2) e = case e of
  Left x -> \f -> Left (dnOnehots t1 x (f . Left))
  Right y -> \f -> Right (dnOnehots t2 y (f . Right))
dnOnehots (STMaybe t) m = case m of
  Nothing -> \_ -> Nothing
  Just x -> \f -> Just (dnOnehots t x (f . Just))
dnOnehots (STArr _ t) a = \f ->
  arrayGenerate (arrayShape a) $ \idx ->
    -- Rebuild the whole array around the probed element: position idx gets
    -- the one-hot dual number, all others their zero-tangent embedding.
    dnOnehots t (arrayIndex a idx)
      (f . (\oh -> arrayGenerate (arrayShape a) $ \i ->
                     if i == idx then oh else dnConst t (arrayIndex a i)))
dnOnehots (STScal t) x = case t of
  STI32 -> \_ -> ()
  STI64 -> \_ -> ()
  STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0)
  STF64 -> \f -> f (x, 1.0)
  STBool -> \_ -> ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"

-- | 'dnConst' applied to every value of an environment.
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil
dnConstEnv (t `SCons` env) (Value x `SCons` val) =
  Value (dnConst t x) `SCons` dnConstEnv env val

-- | Environment analogue of 'RevByFwd'. Given a function that computes the
-- forward derivative for a dual-number environment, it produces the gradient
-- with respect to every value in the environment.
type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env)

-- | 'dnOnehots' lifted to environments. Each entry is probed with 'dnOnehots'
-- while every other entry is held at its zero-tangent embedding ('dnConst').
dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env
dnOnehotEnvs SNil SNil = \_ -> SNil
dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = \f ->
  Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val)))
    `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh))

-- | Reverse derivative of an open expression, computed by forward mode.
--
-- Inputs are the environment types, the expression, the input environment,
-- and an output cotangent @dres@. The result is the cotangent for every input.
--
-- Each probe from 'dnOnehotEnvs' interprets the dual-number transform
-- ('dfwdDN') of the expression and extracts the output tangent with
-- 'unzipDN'. That tangent is then dotted with @dres@ using 'dotprodTan'. This
-- means one interpretation per floating-point scalar of the input.
drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
drevByFwd env expr input dres =
  let outty = typeOf expr
  in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $
     dnOnehotEnvs env input $ \dnInput ->
       -- trace (showEnv (dne env) dnInput) $
       let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr))
       in dotprodTan outty outtan dres