diff options
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r-- | src/ForwardAD.hs | 377 |
1 files changed, 182 insertions, 195 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 0a9e12c..63244a8 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -1,202 +1,189 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +module ForwardAD where --- I want to bring various type variables in scope using type annotations in --- patterns, but I don't want to have to mention all the other type parameters --- of the types in question as well then. Partial type signatures (with '_') are --- useful here. -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} -module ForwardAD ( - dfwd, - FD, FDS, FDE, fd, -) where +import Data.Bifunctor (bimap) +-- import Data.Foldable (toList) +import Array import AST +-- import AST.Bindings import Data - - --- | Dual-numbers transformation -type family FD t where - FD TNil = TNil - FD (TPair a b) = TPair (FD a) (FD b) - FD (TEither a b) = TEither (FD a) (FD b) - FD (TMaybe t) = TMaybe (FD t) - FD (TArr n t) = TArr n (FD t) - FD (TScal t) = FDS t - -type family FDS t where - FDS TF32 = TPair (TScal TF32) (TScal TF32) - FDS TF64 = TPair (TScal TF64) (TScal TF64) - FDS TI32 = TScal TI32 - FDS TI64 = TScal TI64 - FDS TBool = TScal TBool - -type family FDE env where - FDE '[] = '[] - FDE (t : ts) = FD t : FDE ts - -fd :: STy t -> STy (FD t) -fd STNil = STNil -fd (STPair a b) = STPair (fd a) (fd b) -fd (STEither a b) = STEither (fd a) (fd b) -fd (STMaybe t) = STMaybe (fd t) -fd (STArr n t) = STArr n (fd t) -fd (STScal t) = case t of - STF32 -> STPair (STScal STF32) (STScal STF32) - STF64 -> STPair (STScal STF64) (STScal STF64) - STI32 -> STScal STI32 - STI64 -> STScal STI64 - STBool -> STScal STBool -fd STAccum{} = error "Accum in source program" - -fdPreservesTupIx :: SNat n -> FD (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) -fdPreservesTupIx SZ = Refl -fdPreservesTupIx (SS n) | Refl <- fdPreservesTupIx n = Refl - -convIdx :: Idx env t -> Idx (FDE env) (FD t) -convIdx IZ = IZ -convIdx (IS i) = IS (convIdx i) - -scalTyCase :: SScalTy t - -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), FD (TScal t) ~ TPair (TScal t) (TScal t)) => r) - -> (FD (TScal t) ~ TScal t => r) - -> r -scalTyCase STF32 k1 _ = k1 -scalTyCase STF64 k1 _ = k1 -scalTyCase STI32 _ k2 = k2 -scalTyCase STI64 _ k2 = k2 -scalTyCase STBool _ k2 = k2 - --- | Argument does not need to be duplicable. -dop :: forall a b env. SOp a b -> Ex env (FD a) -> Ex env (FD b) -dop = \case - OAdd t -> scalTyCase t - (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy))) - (EOp ext (OAdd t)) - OMul t -> scalTyCase t - (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x)))) - (EOp ext (OMul t)) - ONeg t -> scalTyCase t - (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx))) - (EOp ext (ONeg t)) - OLt t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y))) - (EOp ext (OLt t)) - OLe t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y))) - (EOp ext (OLe t)) - OEq t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) - (EOp ext (OEq t)) - ONot -> EOp ext ONot - OIf -> EOp ext OIf - where - add :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) - add t a b = EOp ext (OAdd t) (EPair ext a b) - - mul :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) - mul t a b = EOp ext (OMul t) (EPair ext a b) - - neg :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) - neg t = EOp ext (ONeg t) - - unFloat :: FD a ~ TPair a a - => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (FD b)) - -> Ex env (FD a) -> Ex env (FD b) - unFloat f e = - ELet ext e $ - let var = EVar ext (typeOf e) IZ - in f (EFst ext var, ESnd ext var) - - binFloat :: (a ~ TPair s s, FD s ~ TPair s s) - => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (FD b)) - -> Ex env (FD a) -> Ex env (FD b) - binFloat f e = - ELet ext e $ - let var = EVar ext (typeOf e) IZ - in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) - (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) - -dfwd :: Ex env t -> Ex (FDE env) (FD t) -dfwd = \case - EVar _ t i -> EVar ext (fd t) (convIdx i) - ELet _ a b -> ELet ext (dfwd a) (dfwd b) - EPair _ a b -> EPair ext (dfwd a) (dfwd b) - EFst _ e -> EFst ext (dfwd e) - ESnd _ e -> ESnd ext (dfwd e) - ENil _ -> ENil ext - EInl _ t e -> EInl ext (fd t) (dfwd e) - EInr _ t e -> EInr ext (fd t) (dfwd e) - ECase _ e a b -> ECase ext (dfwd e) (dfwd a) (dfwd b) - ENothing _ t -> ENothing ext (fd t) - EJust _ e -> EJust ext (dfwd e) - EMaybe _ e a b -> EMaybe ext (dfwd e) (dfwd a) (dfwd b) - EConstArr _ n t x -> scalTyCase t - (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) - (EConstArr ext n t x)) - (EConstArr ext n t x) - EBuild1 _ a b -> EBuild1 ext (dfwd a) (dfwd b) - EBuild _ n a b - | Refl <- fdPreservesTupIx n -> EBuild ext n (dfwd a) (dfwd b) - EFold1Inner _ a b -> EFold1Inner ext (dfwd a) (dfwd b) - ESum1Inner _ e -> - let STArr n (STScal t) = typeOf e - pairty = (STPair (STScal t) (STScal t)) - in scalTyCase t - (ELet ext (dfwd e) $ - ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) - (EVar ext (STArr n pairty) IZ))) - (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ)) - (EVar ext (STArr n pairty) IZ)))) - (ESum1Inner ext (dfwd e)) - EUnit _ e -> EUnit ext (dfwd e) - EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwd a) (dfwd b) - EConst _ t x -> scalTyCase t - (EPair ext (EConst ext t x) (EConst ext t 0.0)) - (EConst ext t x) - EIdx0 _ e -> EIdx0 ext (dfwd e) - EIdx1 _ a b -> EIdx1 ext (dfwd a) (dfwd b) - EIdx _ n a b - | Refl <- fdPreservesTupIx n -> EIdx ext n (dfwd a) (dfwd b) - EShape _ e - | Refl <- fdPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwd e) - EOp _ op e -> dop op (dfwd e) - EError t s -> EError (fd t) s - - EWith{} -> err_accum - EAccum{} -> err_accum - EZero{} -> err_monoid - EPlus{} -> err_monoid - where - err_accum = error "Accumulator operations unsupported in the source program" - err_monoid = error "Monoid operations unsupported in the source program" - -emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr = - let STArr n t = typeOf arr - in ELet ext arr $ - EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) f - -ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip a b = - let STArr n t1 = typeOf a - STArr _ t2 = typeOf b - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ - EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - (EIdx ext n (EVar ext (STArr n t2) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) +import ForwardAD.DualNumbers +import Interpreter +import Interpreter.Rep + + +-- | Tangent along a type (coincides with cotangent for these types) +type family Tan t where + Tan TNil = TNil + Tan (TPair a b) = TPair (Tan a) (Tan b) + Tan (TEither a b) = TEither (Tan a) (Tan b) + Tan (TMaybe t) = TMaybe (Tan t) + Tan (TArr n t) = TArr n (Tan t) + Tan (TScal t) = TanS t + +type family TanS t where + TanS TI32 = TNil + TanS TI64 = TNil + TanS TF32 = TScal TF32 + TanS TF64 = TScal TF64 + TanS TBool = TNil + +type family TanE env where + TanE '[] = '[] + TanE (t : env) = Tan t : TanE env + +tanty :: STy t -> STy (Tan t) +tanty STNil = STNil +tanty (STPair a b) = STPair (tanty a) (tanty b) +tanty (STEither a b) = STEither (tanty a) (tanty b) +tanty (STMaybe t) = STMaybe (tanty t) +tanty (STArr n t) = STArr n (tanty t) +tanty (STScal t) = case t of + STI32 -> STNil + STI64 -> STNil + STF32 -> STScal STF32 + STF64 -> STScal STF64 + STBool -> STNil +tanty STAccum{} = error "Accumulators not allowed in input program" + +unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t)) +unzipDN STNil _ = ((), ()) +unzipDN (STPair a b) (d1, d2) = + let (x, dx) = unzipDN a d1 + (y, dy) = unzipDN b d2 + in ((x, y), (dx, dy)) +unzipDN (STEither a b) d = case d of + Left d1 -> bimap Left Left (unzipDN a d1) + Right d2 -> bimap Right Right (unzipDN b d2) +unzipDN (STMaybe t) d = case d of + Nothing -> (Nothing, Nothing) + Just d' -> bimap Just Just (unzipDN t d') +unzipDN (STArr _ t) d = + let pairs = arrayMap (unzipDN t) d + in (arrayMap fst pairs, arrayMap snd pairs) +unzipDN (STScal ty) d = case ty of + STI32 -> (d, ()) + STI64 -> (d, ()) + STF32 -> d + STF64 -> d + STBool -> (d, ()) +unzipDN STAccum{} _ = error "Accumulators not allowed in input program" + +dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double +dotprodTan STNil _ _ = 0.0 +dotprodTan (STPair a b) (x, y) (x', y') = + dotprodTan a x x' + dotprodTan b y y' +dotprodTan (STEither a b) x y = case (x, y) of + (Left x', Left y') -> dotprodTan a x' y' + (Right x', Right y') -> dotprodTan b x' y' + _ -> error "dotprodTan: incompatible Either alternatives" +dotprodTan (STMaybe t) x y = case (x, y) of + (Nothing, Nothing) -> 0.0 + (Just x', Just y') -> dotprodTan t x' y' + _ -> error "dotprodTan: incompatible Maybe alternatives" +dotprodTan (STArr _ t) x y = + let sh1 = arrayShape x + sh2 = arrayShape y + in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0 + | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1] + | otherwise -> error "dotprodTan: incompatible array shapes" +dotprodTan (STScal ty) x y = case ty of + STI32 -> 0.0 + STI64 -> 0.0 + STF32 -> realToFrac @Float @Double (x * y) + STF64 -> x * y + STBool -> 0.0 +dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" + +-- -- Primal expression must be duplicable +-- dnConstE :: STy t -> Ex env t -> Ex env (DN t) +-- dnConstE STNil _ = ENil ext +-- dnConstE (STPair t1 t2) e = +-- -- This creates fst/snd stacks of unbounded size, but let's not care here +-- EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e)) +-- dnConstE (STEither t1 t2) e = +-- ECase ext e +-- (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ))) +-- (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ))) +-- dnConstE (STMaybe t) e = +-- EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e +-- dnConstE (STArr n t) e = +-- EBuild ext n (EShape ext e) +-- (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ))) +-- dnConstE (STScal t) e = case t of +-- STI32 -> e +-- STI64 -> e +-- STF32 -> EPair ext e (EConst ext STF32 0.0) +-- STF64 -> EPair ext e (EConst ext STF64 0.0) +-- STBool -> e +-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program" + +dnConst :: STy t -> Rep t -> Rep (DN t) +dnConst STNil = const () +dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STMaybe t) = fmap (dnConst t) +dnConst (STArr _ t) = arrayMap (dnConst t) +dnConst (STScal t) = case t of + STI32 -> id + STI64 -> id + STF32 -> (,0.0) + STF64 -> (,0.0) + STBool -> id +dnConst STAccum{} = error "Accumulators not allowed in input program" + +-- | Given a function that computes the forward derivative for a particular +-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this +-- @t@ input. +type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t) + +dnOnehots :: STy t -> Rep t -> RevByFwd t +dnOnehots STNil _ = \_ -> () +dnOnehots (STPair t1 t2) (x, y) = + \f -> (dnOnehots t1 x (f . (,dnConst t2 y)), dnOnehots t2 y (f . (dnConst t1 x,))) +dnOnehots (STEither t1 t2) e = + case e of + Left x -> \f -> Left (dnOnehots t1 x (f . Left)) + Right y -> \f -> Right (dnOnehots t2 y (f . Right)) +dnOnehots (STMaybe t) m = + case m of + Nothing -> \_ -> Nothing + Just x -> \f -> Just (dnOnehots t x (f . Just)) +dnOnehots (STArr _ t) a = + \f -> + arrayGenerate (arrayShape a) $ \idx -> + dnOnehots t (arrayIndex a idx) (f . (\oh -> arrayGenerate (arrayShape a) $ \i -> + if i == idx then oh else dnConst t (arrayIndex a i))) +dnOnehots (STScal t) x = case t of + STI32 -> \_ -> () + STI64 -> \_ -> () + STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0) + STF64 -> \f -> f (x, 1.0) + STBool -> \_ -> () +dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" + +dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) +dnConstEnv SNil SNil = SNil +dnConstEnv (t `SCons` env) (Value x `SCons` val) = Value (dnConst t x) `SCons` dnConstEnv env val + +type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env) + +dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env +dnOnehotEnvs SNil SNil = \_ -> SNil +dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = + \f -> + Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val))) + `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh)) + +drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) +drevByFwd env expr input dres = + let outty = typeOf expr + in dnOnehotEnvs env input $ \dnInput -> + let (_, outtan) = unzipDN outty (interpretOpen dnInput (dfwdDN expr)) + in dotprodTan outty outtan dres |