{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Forward-mode automatic differentiation by source transformation.
--
-- 'ad' rewrites an expression of type @a@ into one of type @'Dual' a@, in
-- which every 'Double' travels together with its tangent.
module AD (
  Dual,
  dual,
  ad,
) where

import Data.Bifunctor
import Data.Type.Equality

import AST
import qualified Language as L
import Sink

-- | Dual-number version of a type: every 'Double' is paired up with a
-- tangent copy, as @(primal, tangent)@. 'Int', 'Bool' and @()@ carry no
-- tangent and are left alone; pairs, arrays and functions are mapped
-- structurally.
type family Dual a where
  Dual () = ()
  Dual Int = Int
  Dual Bool = Bool
  Dual Double = (Double, Double)
  Dual (a, b) = (Dual a, Dual b)
  Dual (Array sh a) = Array sh (Dual a)
  Dual (a -> b) = Dual a -> Dual b

-- | Relates the environment of the input term (@env@) to the environment of
-- the differentiated term (@env'@).
--
-- Each 'ECons' is a variable bound inside the term being differentiated; in
-- the output that variable has type @'Dual' a@. Below 'ETop' the two
-- environments coincide: those are the free variables of the input term,
-- which keep their original types.
data DEnv env env' where
  ETop :: DEnv env env
  ECons :: Type a -> DEnv env env' -> DEnv (a ': env) (Dual a ': env')

-- | Compute the 'Type' witness of @'Dual' a@ from that of @a@; the
-- value-level mirror of the 'Dual' type family.
dual :: Type a -> Type (Dual a)
dual = \case
  TNil -> TNil
  TInt -> TInt
  TBool -> TBool
  TDouble -> TPair TDouble TDouble
  TPair l r -> TPair (dual l) (dual r)
  TArray shape elt -> TArray shape (dual elt)
  TFun arg res -> TFun (dual arg) (dual res)

-- | Forward AD: transform an expression into its dual-number version.
--
-- Since @'Dual' (a -> b) = 'Dual' a -> 'Dual' b@, calling 'ad' on a
-- function-typed expression gives the most useful results: the resulting
-- function takes @(primal, tangent)@ pairs for its 'Double' inputs and
-- returns such pairs for its 'Double' outputs.
--
-- Free variables of the input expression are treated as constants (their
-- tangents are zero). Only some primitive constants are supported (integer
-- arithmetic, and 'Double' addition, subtraction and multiplication);
-- differentiating any other constant fails at runtime.
ad :: Exp env a -> Exp env (Dual a)
ad = ad' ETop

-- | Worker for 'ad'. The 'DEnv' tracks which variables in scope were bound
-- inside the term being differentiated (they already have dual type in the
-- output) and which are free in it (they keep their primal type).
ad' :: DEnv env env' -> Exp env a -> Exp env' (Dual a)
ad' env = \case
  App e1 e2 -> App (ad' env e1) (ad' env e2)
  Lam t e -> Lam (dual t) (ad' (ECons t env) e)
  Var t i -> case convIdx env i of
    -- Free in the input term: still has its primal type, so lift it into
    -- its dual (zero tangent) at the use site.
    Left freei -> App (duale t) (Var t freei)
    -- Bound inside the input term: its binder was already dualised.
    Right duali -> Var (dual t) duali
  Let e1 e2 -> Let (ad' env e1) (ad' (ECons (typeof e1) env) e2)
  Lit l -> App (duale (literalType l)) (Lit l)
  Cond e1 e2 e3 -> Cond (ad' env e1) (ad' env e2) (ad' env e3)
  -- 'Dual' Int = Int, so integer primitives are unchanged.
  Const CAddI -> Const CAddI
  Const CSubI -> Const CSubI
  Const CMulI -> Const CMulI
  Const CDivI -> Const CDivI
  -- d(x + y) = dx + dy
  Const CAddF -> dualBinopF $ \x dx y dy ->
    Pair (App (Const CAddF) (Pair x y))
         (App (Const CAddF) (Pair dx dy))
  -- d(x - y) = dx - dy
  Const CSubF -> dualBinopF $ \x dx y dy ->
    Pair (App (Const CSubF) (Pair x y))
         (App (Const CSubF) (Pair dx dy))
  -- Product rule: d(x * y) = x * dy + y * dx
  Const CMulF -> dualBinopF $ \x dx y dy ->
    Pair (App (Const CMulF) (Pair x y))
         (App (Const CAddF) (Pair (App (Const CMulF) (Pair x dy))
                                  (App (Const CMulF) (Pair y dx))))
  -- Previously 'undefined'; fail with a message that says what went wrong.
  Const _ -> error "AD.ad': differentiating this primitive constant is not supported"
  Pair e1 e2 -> Pair (ad' env e1) (ad' env e2)
  Fst e -> Fst (ad' env e)
  Snd e -> Snd (ad' env e)
  -- 'prfDualSht' shows shape types are their own dual, so shape-typed
  -- subterms can be passed through with the same 'ShapeType'.
  Build sht e1 e2
    | Refl <- prfDualSht sht ->
        Build sht (ad' env e1) (ad' env e2)
  Ifold sht e1 e2 e3
    | Refl <- prfDualSht sht ->
        Ifold sht (ad' env e1) (ad' env e2) (ad' env e3)
  Index e1 e2
    | TArray sht _ <- typeof e1
    , Refl <- prfDualSht sht ->
        Index (ad' env e1) (ad' env e2)
  Shape e
    | TArray sht _ <- typeof e
    , Refl <- prfDualSht sht ->
        Shape (ad' env e)

-- | Argument of a dualised binary 'Double' primitive: @((x, dx), (y, dy))@.
type DualBinopArg = ((Double, Double), (Double, Double))

-- | 'Type' witness for 'DualBinopArg'.
dualBinopArgTy :: Type DualBinopArg
dualBinopArgTy = TPair (TPair TDouble TDouble) (TPair TDouble TDouble)

-- | @dualBinopF f@ builds the object-language function
-- @\\((x, dx), (y, dy)) -> f x dx y dy@, where @f@ constructs the
-- @(primal, tangent)@ result from the four components of the argument.
dualBinopF
  :: ( Exp (DualBinopArg ': env) Double
    -> Exp (DualBinopArg ': env) Double
    -> Exp (DualBinopArg ': env) Double
    -> Exp (DualBinopArg ': env) Double
    -> Exp (DualBinopArg ': env) (Double, Double) )
  -> Exp env (DualBinopArg -> (Double, Double))
dualBinopF f = Lam dualBinopArgTy (f x dx y dy)
  where
    arg = Var dualBinopArgTy Zero
    x = Fst (Fst arg)
    dx = Snd (Fst arg)
    y = Fst (Snd arg)
    dy = Snd (Snd arg)

-- | Translate a variable index from the input environment to the output
-- environment. 'Left': the variable is free in the term being differentiated
-- and keeps its type. 'Right': it was bound inside that term and now has its
-- 'Dual' type.
convIdx :: DEnv env env' -> Idx env a -> Either (Idx env' a) (Idx env' (Dual a))
convIdx ETop i = Left i
convIdx (ECons _ _) Zero = Right Zero
convIdx (ECons _ env) (Succ i) = bimap Succ Succ (convIdx env i)

-- | Object-language function embedding a value of type @a@ into @'Dual' a@
-- with all tangents zero. A function is embedded by converting its argument
-- back with 'unduale' (dropping the argument's tangents), applying it, and
-- embedding the result; the embedded function therefore always yields zero
-- tangents.
duale :: Type a -> Exp env (a -> Dual a)
duale topty = Lam topty (go topty)
  where
    go :: forall a env. Type a -> Exp (a ': env) (Dual a)
    go ty = case ty of
        TInt -> ref
        TBool -> ref
        TDouble -> Pair ref (Lit (LDouble 0))
        TArray _ t -> L.map (duale t) ref
        TNil -> Lit LNil
        TPair t1 t2 -> Pair (Let (Fst ref) (go t1)) (Let (Snd ref) (go t2))
        TFun t1 t2 ->
          Lam (dual t1)
              (App (duale t2)
                   (App (sinkExp1 ref)
                        (App (unduale t1) (Var (dual t1) Zero))))
      where
        -- The value being converted: the variable bound by the enclosing 'Lam'.
        ref :: Exp (a ': env) a
        ref = Var ty Zero

-- | Object-language function projecting a @'Dual' a@ back to @a@ by
-- discarding all tangents. A function is projected by embedding its argument
-- with 'duale' (zero tangents), applying it, and projecting the result.
unduale :: Type a -> Exp env (Dual a -> a)
unduale topty = Lam (dual topty) (go topty)
  where
    go :: forall a env. Type a -> Exp (Dual a ': env) a
    go ty = case ty of
        TInt -> ref
        TBool -> ref
        TDouble -> Fst ref
        TArray _ t -> L.map (unduale t) ref
        TNil -> Lit LNil
        TPair t1 t2 -> Pair (Let (Fst ref) (go t1)) (Let (Snd ref) (go t2))
        TFun t1 t2 ->
          Lam t1
              (App (unduale t2)
                   (App (sinkExp1 ref)
                        (App (duale t1) (Var t1 Zero))))
      where
        -- The dual value being projected: the variable bound by the enclosing 'Lam'.
        ref :: Exp (Dual a ': env) (Dual a)
        ref = Var (dual ty) Zero

-- | Shape types are their own dual, so array shapes pass through 'ad'
-- unchanged.
prfDualSht :: ShapeType sh -> sh :~: Dual sh
prfDualSht STZ = Refl
prfDualSht (STC sht) | Refl <- prfDualSht sht = Refl