summaryrefslogtreecommitdiff
path: root/src/ForwardAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r--src/ForwardAD.hs377
1 files changed, 182 insertions, 195 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index 0a9e12c..63244a8 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -1,202 +1,189 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+module ForwardAD where
--- I want to bring various type variables in scope using type annotations in
--- patterns, but I don't want to have to mention all the other type parameters
--- of the types in question as well then. Partial type signatures (with '_') are
--- useful here.
-{-# LANGUAGE PartialTypeSignatures #-}
-{-# OPTIONS -Wno-partial-type-signatures #-}
-module ForwardAD (
- dfwd,
- FD, FDS, FDE, fd,
-) where
+import Data.Bifunctor (bimap)
+-- import Data.Foldable (toList)
+import Array
import AST
+-- import AST.Bindings
import Data
-
-
--- | Dual-numbers transformation
-type family FD t where
- FD TNil = TNil
- FD (TPair a b) = TPair (FD a) (FD b)
- FD (TEither a b) = TEither (FD a) (FD b)
- FD (TMaybe t) = TMaybe (FD t)
- FD (TArr n t) = TArr n (FD t)
- FD (TScal t) = FDS t
-
-type family FDS t where
- FDS TF32 = TPair (TScal TF32) (TScal TF32)
- FDS TF64 = TPair (TScal TF64) (TScal TF64)
- FDS TI32 = TScal TI32
- FDS TI64 = TScal TI64
- FDS TBool = TScal TBool
-
-type family FDE env where
- FDE '[] = '[]
- FDE (t : ts) = FD t : FDE ts
-
-fd :: STy t -> STy (FD t)
-fd STNil = STNil
-fd (STPair a b) = STPair (fd a) (fd b)
-fd (STEither a b) = STEither (fd a) (fd b)
-fd (STMaybe t) = STMaybe (fd t)
-fd (STArr n t) = STArr n (fd t)
-fd (STScal t) = case t of
- STF32 -> STPair (STScal STF32) (STScal STF32)
- STF64 -> STPair (STScal STF64) (STScal STF64)
- STI32 -> STScal STI32
- STI64 -> STScal STI64
- STBool -> STScal STBool
-fd STAccum{} = error "Accum in source program"
-
-fdPreservesTupIx :: SNat n -> FD (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
-fdPreservesTupIx SZ = Refl
-fdPreservesTupIx (SS n) | Refl <- fdPreservesTupIx n = Refl
-
-convIdx :: Idx env t -> Idx (FDE env) (FD t)
-convIdx IZ = IZ
-convIdx (IS i) = IS (convIdx i)
-
-scalTyCase :: SScalTy t
- -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), FD (TScal t) ~ TPair (TScal t) (TScal t)) => r)
- -> (FD (TScal t) ~ TScal t => r)
- -> r
-scalTyCase STF32 k1 _ = k1
-scalTyCase STF64 k1 _ = k1
-scalTyCase STI32 _ k2 = k2
-scalTyCase STI64 _ k2 = k2
-scalTyCase STBool _ k2 = k2
-
--- | Argument does not need to be duplicable.
-dop :: forall a b env. SOp a b -> Ex env (FD a) -> Ex env (FD b)
-dop = \case
- OAdd t -> scalTyCase t
- (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy)))
- (EOp ext (OAdd t))
- OMul t -> scalTyCase t
- (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x))))
- (EOp ext (OMul t))
- ONeg t -> scalTyCase t
- (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx)))
- (EOp ext (ONeg t))
- OLt t -> scalTyCase t
- (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y)))
- (EOp ext (OLt t))
- OLe t -> scalTyCase t
- (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y)))
- (EOp ext (OLe t))
- OEq t -> scalTyCase t
- (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y)))
- (EOp ext (OEq t))
- ONot -> EOp ext ONot
- OIf -> EOp ext OIf
- where
- add :: ScalIsNumeric t ~ True
- => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
- add t a b = EOp ext (OAdd t) (EPair ext a b)
-
- mul :: ScalIsNumeric t ~ True
- => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
- mul t a b = EOp ext (OMul t) (EPair ext a b)
-
- neg :: ScalIsNumeric t ~ True
- => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
- neg t = EOp ext (ONeg t)
-
- unFloat :: FD a ~ TPair a a
- => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (FD b))
- -> Ex env (FD a) -> Ex env (FD b)
- unFloat f e =
- ELet ext e $
- let var = EVar ext (typeOf e) IZ
- in f (EFst ext var, ESnd ext var)
-
- binFloat :: (a ~ TPair s s, FD s ~ TPair s s)
- => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (FD b))
- -> Ex env (FD a) -> Ex env (FD b)
- binFloat f e =
- ELet ext e $
- let var = EVar ext (typeOf e) IZ
- in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
- (EFst ext (ESnd ext var), ESnd ext (ESnd ext var))
-
-dfwd :: Ex env t -> Ex (FDE env) (FD t)
-dfwd = \case
- EVar _ t i -> EVar ext (fd t) (convIdx i)
- ELet _ a b -> ELet ext (dfwd a) (dfwd b)
- EPair _ a b -> EPair ext (dfwd a) (dfwd b)
- EFst _ e -> EFst ext (dfwd e)
- ESnd _ e -> ESnd ext (dfwd e)
- ENil _ -> ENil ext
- EInl _ t e -> EInl ext (fd t) (dfwd e)
- EInr _ t e -> EInr ext (fd t) (dfwd e)
- ECase _ e a b -> ECase ext (dfwd e) (dfwd a) (dfwd b)
- ENothing _ t -> ENothing ext (fd t)
- EJust _ e -> EJust ext (dfwd e)
- EMaybe _ e a b -> EMaybe ext (dfwd e) (dfwd a) (dfwd b)
- EConstArr _ n t x -> scalTyCase t
- (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
- (EConstArr ext n t x))
- (EConstArr ext n t x)
- EBuild1 _ a b -> EBuild1 ext (dfwd a) (dfwd b)
- EBuild _ n a b
- | Refl <- fdPreservesTupIx n -> EBuild ext n (dfwd a) (dfwd b)
- EFold1Inner _ a b -> EFold1Inner ext (dfwd a) (dfwd b)
- ESum1Inner _ e ->
- let STArr n (STScal t) = typeOf e
- pairty = (STPair (STScal t) (STScal t))
- in scalTyCase t
- (ELet ext (dfwd e) $
- ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ))
- (EVar ext (STArr n pairty) IZ)))
- (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ))
- (EVar ext (STArr n pairty) IZ))))
- (ESum1Inner ext (dfwd e))
- EUnit _ e -> EUnit ext (dfwd e)
- EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwd a) (dfwd b)
- EConst _ t x -> scalTyCase t
- (EPair ext (EConst ext t x) (EConst ext t 0.0))
- (EConst ext t x)
- EIdx0 _ e -> EIdx0 ext (dfwd e)
- EIdx1 _ a b -> EIdx1 ext (dfwd a) (dfwd b)
- EIdx _ n a b
- | Refl <- fdPreservesTupIx n -> EIdx ext n (dfwd a) (dfwd b)
- EShape _ e
- | Refl <- fdPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwd e)
- EOp _ op e -> dop op (dfwd e)
- EError t s -> EError (fd t) s
-
- EWith{} -> err_accum
- EAccum{} -> err_accum
- EZero{} -> err_monoid
- EPlus{} -> err_monoid
- where
- err_accum = error "Accumulator operations unsupported in the source program"
- err_monoid = error "Monoid operations unsupported in the source program"
-
-emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b)
-emap f arr =
- let STArr n t = typeOf arr
- in ELet ext arr $
- EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
- ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) f
-
-ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
-ezip a b =
- let STArr n t1 = typeOf a
- STArr _ t2 = typeOf b
- in ELet ext a $
- ELet ext (weakenExpr WSink b) $
- EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
- EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
- (EIdx ext n (EVar ext (STArr n t2) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
+import ForwardAD.DualNumbers
+import Interpreter
+import Interpreter.Rep
+
+
+-- | Tangent along a type (coincides with cotangent for these types)
+type family Tan t where
+ Tan TNil = TNil
+ Tan (TPair a b) = TPair (Tan a) (Tan b)
+ Tan (TEither a b) = TEither (Tan a) (Tan b)
+ Tan (TMaybe t) = TMaybe (Tan t)
+ Tan (TArr n t) = TArr n (Tan t)
+ Tan (TScal t) = TanS t
+
+type family TanS t where
+ TanS TI32 = TNil
+ TanS TI64 = TNil
+ TanS TF32 = TScal TF32
+ TanS TF64 = TScal TF64
+ TanS TBool = TNil
+
+type family TanE env where
+ TanE '[] = '[]
+ TanE (t : env) = Tan t : TanE env
+
+tanty :: STy t -> STy (Tan t)
+tanty STNil = STNil
+tanty (STPair a b) = STPair (tanty a) (tanty b)
+tanty (STEither a b) = STEither (tanty a) (tanty b)
+tanty (STMaybe t) = STMaybe (tanty t)
+tanty (STArr n t) = STArr n (tanty t)
+tanty (STScal t) = case t of
+ STI32 -> STNil
+ STI64 -> STNil
+ STF32 -> STScal STF32
+ STF64 -> STScal STF64
+ STBool -> STNil
+tanty STAccum{} = error "Accumulators not allowed in input program"
+
+unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
+unzipDN STNil _ = ((), ())
+unzipDN (STPair a b) (d1, d2) =
+ let (x, dx) = unzipDN a d1
+ (y, dy) = unzipDN b d2
+ in ((x, y), (dx, dy))
+unzipDN (STEither a b) d = case d of
+ Left d1 -> bimap Left Left (unzipDN a d1)
+ Right d2 -> bimap Right Right (unzipDN b d2)
+unzipDN (STMaybe t) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just d' -> bimap Just Just (unzipDN t d')
+unzipDN (STArr _ t) d =
+ let pairs = arrayMap (unzipDN t) d
+ in (arrayMap fst pairs, arrayMap snd pairs)
+unzipDN (STScal ty) d = case ty of
+ STI32 -> (d, ())
+ STI64 -> (d, ())
+ STF32 -> d
+ STF64 -> d
+ STBool -> (d, ())
+unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
+
+dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
+dotprodTan STNil _ _ = 0.0
+dotprodTan (STPair a b) (x, y) (x', y') =
+ dotprodTan a x x' + dotprodTan b y y'
+dotprodTan (STEither a b) x y = case (x, y) of
+ (Left x', Left y') -> dotprodTan a x' y'
+ (Right x', Right y') -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible Either alternatives"
+dotprodTan (STMaybe t) x y = case (x, y) of
+ (Nothing, Nothing) -> 0.0
+ (Just x', Just y') -> dotprodTan t x' y'
+ _ -> error "dotprodTan: incompatible Maybe alternatives"
+dotprodTan (STArr _ t) x y =
+ let sh1 = arrayShape x
+ sh2 = arrayShape y
+ in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0
+ | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1]
+ | otherwise -> error "dotprodTan: incompatible array shapes"
+dotprodTan (STScal ty) x y = case ty of
+ STI32 -> 0.0
+ STI64 -> 0.0
+ STF32 -> realToFrac @Float @Double (x * y)
+ STF64 -> x * y
+ STBool -> 0.0
+dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
+
+-- -- Primal expression must be duplicable
+-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
+-- dnConstE STNil _ = ENil ext
+-- dnConstE (STPair t1 t2) e =
+-- -- This creates fst/snd stacks of unbounded size, but let's not care here
+-- EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e))
+-- dnConstE (STEither t1 t2) e =
+-- ECase ext e
+-- (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ)))
+-- (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ)))
+-- dnConstE (STMaybe t) e =
+-- EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e
+-- dnConstE (STArr n t) e =
+-- EBuild ext n (EShape ext e)
+-- (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ)))
+-- dnConstE (STScal t) e = case t of
+-- STI32 -> e
+-- STI64 -> e
+-- STF32 -> EPair ext e (EConst ext STF32 0.0)
+-- STF64 -> EPair ext e (EConst ext STF64 0.0)
+-- STBool -> e
+-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program"
+
+dnConst :: STy t -> Rep t -> Rep (DN t)
+dnConst STNil = const ()
+dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STMaybe t) = fmap (dnConst t)
+dnConst (STArr _ t) = arrayMap (dnConst t)
+dnConst (STScal t) = case t of
+ STI32 -> id
+ STI64 -> id
+ STF32 -> (,0.0)
+ STF64 -> (,0.0)
+ STBool -> id
+dnConst STAccum{} = error "Accumulators not allowed in input program"
+
+-- | Given a function that computes the forward derivative for a particular
+-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
+-- @t@ input.
+type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t)
+
+dnOnehots :: STy t -> Rep t -> RevByFwd t
+dnOnehots STNil _ = \_ -> ()
+dnOnehots (STPair t1 t2) (x, y) =
+ \f -> (dnOnehots t1 x (f . (,dnConst t2 y)), dnOnehots t2 y (f . (dnConst t1 x,)))
+dnOnehots (STEither t1 t2) e =
+ case e of
+ Left x -> \f -> Left (dnOnehots t1 x (f . Left))
+ Right y -> \f -> Right (dnOnehots t2 y (f . Right))
+dnOnehots (STMaybe t) m =
+ case m of
+ Nothing -> \_ -> Nothing
+ Just x -> \f -> Just (dnOnehots t x (f . Just))
+dnOnehots (STArr _ t) a =
+ \f ->
+ arrayGenerate (arrayShape a) $ \idx ->
+ dnOnehots t (arrayIndex a idx) (f . (\oh -> arrayGenerate (arrayShape a) $ \i ->
+ if i == idx then oh else dnConst t (arrayIndex a i)))
+dnOnehots (STScal t) x = case t of
+ STI32 -> \_ -> ()
+ STI64 -> \_ -> ()
+ STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0)
+ STF64 -> \f -> f (x, 1.0)
+ STBool -> \_ -> ()
+dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
+
+dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
+dnConstEnv SNil SNil = SNil
+dnConstEnv (t `SCons` env) (Value x `SCons` val) = Value (dnConst t x) `SCons` dnConstEnv env val
+
+type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env)
+
+dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env
+dnOnehotEnvs SNil SNil = \_ -> SNil
+dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =
+ \f ->
+ Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val)))
+ `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh))
+
+drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwd env expr input dres =
+ let outty = typeOf expr
+ in dnOnehotEnvs env input $ \dnInput ->
+ let (_, outtan) = unzipDN outty (interpretOpen dnInput (dfwdDN expr))
+ in dotprodTan outty outtan dres