aboutsummaryrefslogtreecommitdiff
path: root/AD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'AD.hs')
-rw-r--r--AD.hs152
1 files changed, 152 insertions, 0 deletions
diff --git a/AD.hs b/AD.hs
new file mode 100644
index 0000000..76fefe4
--- /dev/null
+++ b/AD.hs
@@ -0,0 +1,152 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module AD (
+ Dual,
+ dual,
+ ad,
+) where
+
+import Data.Bifunctor
+import Data.Type.Equality
+
+import AST
+import qualified Language as L
+import Sink
+
+
+-- | Dual-number version of a type. Pairs up every Double with a tangent copy.
+type family Dual a where
+ Dual () = ()
+ Dual Int = Int
+ Dual Bool = Bool
+ Dual Double = (Double, Double)
+ Dual (a, b) = (Dual a, Dual b)
+ Dual (Array sh a) = Array sh (Dual a)
+ Dual (a -> b) = Dual a -> Dual b
+
+data DEnv env env' where
+ ETop :: DEnv env env
+ ECons :: Type a -> DEnv env env' -> DEnv (a ': env) (Dual a ': env')
+
+-- | Convert a type to its 'Dual' variant.
+dual :: Type a -> Type (Dual a)
+dual (TFun a b) = TFun (dual a) (dual b)
+dual TInt = TInt
+dual TBool = TBool
+dual TDouble = TPair TDouble TDouble
+dual (TArray sht t) = TArray sht (dual t)
+dual TNil = TNil
+dual (TPair a b) = TPair (dual a) (dual b)
+
+-- | Forward AD.
+--
+-- Since @'Dual' (a -> b) = 'Dual' a -> 'Dual' b@, calling 'ad' on a
+-- function-typed expression gives the most useful results.
+ad :: Exp env a -> Exp env (Dual a)
+ad = ad' ETop
+
+ad' :: DEnv env env' -> Exp env a -> Exp env' (Dual a)
+ad' env = \case
+ App e1 e2 -> App (ad' env e1) (ad' env e2)
+ Lam t e -> Lam (dual t) (ad' (ECons t env) e)
+ Var t i ->
+ case convIdx env i of
+ Left freei -> App (duale t) (Var t freei)
+ Right duali -> Var (dual t) duali
+ Let e1 e2 -> Let (ad' env e1) (ad' (ECons (typeof e1) env) e2)
+ Lit l -> App (duale (literalType l)) (Lit l)
+ Cond e1 e2 e3 -> Cond (ad' env e1) (ad' env e2) (ad' env e3)
+ Const CAddI -> Const CAddI
+ Const CSubI -> Const CSubI
+ Const CMulI -> Const CMulI
+ Const CDivI -> Const CDivI
+ Const CAddF ->
+ let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero
+ in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble))
+ (Pair (App (Const CAddF) (Pair (Fst (Fst v)) (Fst (Snd v))))
+ (App (Const CAddF) (Pair (Snd (Fst v)) (Snd (Snd v)))))
+ Const CSubF ->
+ let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero
+ in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble))
+ (Pair (App (Const CSubF) (Pair (Fst (Fst v)) (Fst (Snd v))))
+ (App (Const CSubF) (Pair (Snd (Fst v)) (Snd (Snd v)))))
+ Const CMulF ->
+ let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero
+ in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble))
+ (Pair (App (Const CMulF) (Pair (Fst (Fst v)) (Fst (Snd v))))
+ (App (Const CAddF) (Pair
+ (App (Const CMulF) (Pair (Fst (Fst v)) (Snd (Snd v))))
+ (App (Const CMulF) (Pair (Fst (Snd v)) (Snd (Fst v)))))))
+ Const _ -> undefined
+ Pair e1 e2 -> Pair (ad' env e1) (ad' env e2)
+ Fst e -> Fst (ad' env e)
+ Snd e -> Snd (ad' env e)
+ Build sht e1 e2
+ | Refl <- prfDualSht sht
+ -> Build sht (ad' env e1) (ad' env e2)
+ Ifold sht e1 e2 e3
+ | Refl <- prfDualSht sht
+ -> Ifold sht (ad' env e1) (ad' env e2) (ad' env e3)
+ Index e1 e2
+ | TArray sht _ <- typeof e1
+ , Refl <- prfDualSht sht
+ -> Index (ad' env e1) (ad' env e2)
+ Shape e
+ | TArray sht _ <- typeof e
+ , Refl <- prfDualSht sht
+ -> Shape (ad' env e)
+
+convIdx :: DEnv env env' -> Idx env a -> Either (Idx env' a) (Idx env' (Dual a))
+convIdx ETop i = Left i
+convIdx (ECons _ _) Zero = Right Zero
+convIdx (ECons _ env) (Succ i) = bimap Succ Succ (convIdx env i)
+
+duale :: Type a -> Exp env (a -> Dual a)
+duale topty = Lam topty (go topty)
+ where
+ go :: forall a env. Type a -> Exp (a ': env) (Dual a)
+ go ty = case ty of
+ TInt -> ref
+ TBool -> ref
+ TDouble -> Pair ref (Lit (LDouble 0))
+ TArray _ t -> L.map (Lam t (go t)) ref
+ TNil -> Lit LNil
+ TPair t1 t2 -> Pair (Let (Fst ref) (go t1)) (Let (Snd ref) (go t2))
+ TFun t1 t2 -> Lam (dual t1)
+ (App (duale t2)
+ (App (sinkExp1 ref)
+ (App (unduale t1)
+ (Var (dual t1) Zero))))
+ where
+ ref :: Exp (a ': env) a
+ ref = Var ty Zero
+
+unduale :: Type a -> Exp env (Dual a -> a)
+unduale topty = Lam (dual topty) (go topty)
+ where
+ go :: forall a env. Type a -> Exp (Dual a ': env) a
+ go ty = case ty of
+ TInt -> ref
+ TBool -> ref
+ TDouble -> Fst ref
+ TArray _ t -> L.map (Lam (dual t) (go t)) ref
+ TNil -> Lit LNil
+ TPair t1 t2 -> Pair (Let (Fst ref) (go t1)) (Let (Snd ref) (go t2))
+ TFun t1 t2 -> Lam t1
+ (App (unduale t2)
+ (App (sinkExp1 ref)
+ (App (duale t1)
+ (Var t1 Zero))))
+ where
+ ref :: Exp (Dual a ': env) (Dual a)
+ ref = Var (dual ty) Zero
+
+prfDualSht :: ShapeType sh -> sh :~: Dual sh
+prfDualSht STZ = Refl
+prfDualSht (STC sht)
+ | Refl <- prfDualSht sht
+ = Refl