diff options
Diffstat (limited to 'AD.hs')
-rw-r--r-- | AD.hs | 152 |
1 files changed, 152 insertions, 0 deletions
@@ -0,0 +1,152 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module AD ( + Dual, + dual, + ad, +) where + +import Data.Bifunctor +import Data.Type.Equality + +import AST +import qualified Language as L +import Sink + + +-- | Dual-number version of a type. Pairs up every Double with a tangent copy. +type family Dual a where + Dual () = () + Dual Int = Int + Dual Bool = Bool + Dual Double = (Double, Double) + Dual (a, b) = (Dual a, Dual b) + Dual (Array sh a) = Array sh (Dual a) + Dual (a -> b) = Dual a -> Dual b + +data DEnv env env' where + ETop :: DEnv env env + ECons :: Type a -> DEnv env env' -> DEnv (a ': env) (Dual a ': env') + +-- | Convert a type to its 'Dual' variant. +dual :: Type a -> Type (Dual a) +dual (TFun a b) = TFun (dual a) (dual b) +dual TInt = TInt +dual TBool = TBool +dual TDouble = TPair TDouble TDouble +dual (TArray sht t) = TArray sht (dual t) +dual TNil = TNil +dual (TPair a b) = TPair (dual a) (dual b) + +-- | Forward AD. +-- +-- Since @'Dual' (a -> b) = 'Dual' a -> 'Dual' b@, calling 'ad' on a +-- function-typed expression gives the most useful results. +ad :: Exp env a -> Exp env (Dual a) +ad = ad' ETop + +ad' :: DEnv env env' -> Exp env a -> Exp env' (Dual a) +ad' env = \case + App e1 e2 -> App (ad' env e1) (ad' env e2) + Lam t e -> Lam (dual t) (ad' (ECons t env) e) + Var t i -> + case convIdx env i of + Left freei -> App (duale t) (Var t freei) + Right duali -> Var (dual t) duali + Let e1 e2 -> Let (ad' env e1) (ad' (ECons (typeof e1) env) e2) + Lit l -> App (duale (literalType l)) (Lit l) + Cond e1 e2 e3 -> Cond (ad' env e1) (ad' env e2) (ad' env e3) + Const CAddI -> Const CAddI + Const CSubI -> Const CSubI + Const CMulI -> Const CMulI + Const CDivI -> Const CDivI + Const CAddF -> + let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero + in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) + (Pair (App (Const CAddF) (Pair (Fst (Fst v)) (Fst (Snd v)))) + (App (Const CAddF) (Pair (Snd (Fst v)) (Snd (Snd v))))) + Const CSubF -> + let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero + in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) + (Pair (App (Const CSubF) (Pair (Fst (Fst v)) (Fst (Snd v)))) + (App (Const CSubF) (Pair (Snd (Fst v)) (Snd (Snd v))))) + Const CMulF -> + let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero + in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) + (Pair (App (Const CMulF) (Pair (Fst (Fst v)) (Fst (Snd v)))) + (App (Const CAddF) (Pair + (App (Const CMulF) (Pair (Fst (Fst v)) (Snd (Snd v)))) + (App (Const CMulF) (Pair (Fst (Snd v)) (Snd (Fst v))))))) + Const _ -> undefined + Pair e1 e2 -> Pair (ad' env e1) (ad' env e2) + Fst e -> Fst (ad' env e) + Snd e -> Snd (ad' env e) + Build sht e1 e2 + | Refl <- prfDualSht sht + -> Build sht (ad' env e1) (ad' env e2) + Ifold sht e1 e2 e3 + | Refl <- prfDualSht sht + -> Ifold sht (ad' env e1) (ad' env e2) (ad' env e3) + Index e1 e2 + | TArray sht _ <- typeof e1 + , Refl <- prfDualSht sht + -> Index (ad' env e1) (ad' env e2) + Shape e + | TArray sht _ <- typeof e + , Refl <- prfDualSht sht + -> Shape (ad' env e) + +convIdx :: DEnv env env' -> Idx env a -> Either (Idx env' a) (Idx env' (Dual a)) +convIdx ETop i = Left i +convIdx (ECons _ _) Zero = Right Zero +convIdx (ECons _ env) (Succ i) = bimap Succ Succ (convIdx env i) + +duale :: Type a -> Exp env (a -> Dual a) +duale topty = Lam topty (go topty) + where + go :: forall a env. Type a -> Exp (a ': env) (Dual a) + go ty = case ty of + TInt -> ref + TBool -> ref + TDouble -> Pair ref (Lit (LDouble 0)) + TArray _ t -> L.map (Lam t (go t)) ref + TNil -> Lit LNil + TPair t1 t2 -> Pair (Let (Fst ref) (go t1)) (Let (Snd ref) (go t2)) + TFun t1 t2 -> Lam (dual t1) + (App (duale t2) + (App (sinkExp1 ref) + (App (unduale t1) + (Var (dual t1) Zero)))) + where + ref :: Exp (a ': env) a + ref = Var ty Zero + +unduale :: Type a -> Exp env (Dual a -> a) +unduale topty = Lam (dual topty) (go topty) + where + go :: forall a env. Type a -> Exp (Dual a ': env) a + go ty = case ty of + TInt -> ref + TBool -> ref + TDouble -> Fst ref + TArray _ t -> L.map (Lam (dual t) (go t)) ref + TNil -> Lit LNil + TPair t1 t2 -> Pair (Let (Fst ref) (go t1)) (Let (Snd ref) (go t2)) + TFun t1 t2 -> Lam t1 + (App (unduale t2) + (App (sinkExp1 ref) + (App (duale t1) + (Var t1 Zero)))) + where + ref :: Exp (Dual a ': env) (Dual a) + ref = Var (dual ty) Zero + +prfDualSht :: ShapeType sh -> sh :~: Dual sh +prfDualSht STZ = Refl +prfDualSht (STC sht) + | Refl <- prfDualSht sht + = Refl |