aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/ForwardAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/ForwardAD.hs')
-rw-r--r--src/CHAD/ForwardAD.hs270
1 files changed, 270 insertions, 0 deletions
diff --git a/src/CHAD/ForwardAD.hs b/src/CHAD/ForwardAD.hs
new file mode 100644
index 0000000..7126e10
--- /dev/null
+++ b/src/CHAD/ForwardAD.hs
@@ -0,0 +1,270 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.ForwardAD where
+
+import Data.Bifunctor (bimap)
+import System.IO.Unsafe
+
+-- import Debug.Trace
+-- import CHAD.AST.Pretty
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.Compile
+import CHAD.Data
+import CHAD.ForwardAD.DualNumbers
+import CHAD.Interpreter
+import CHAD.Interpreter.Rep
+
+
+-- | Tangent along a type (coincides with cotangent for these types)
+type family Tan t where
+ Tan TNil = TNil
+ Tan (TPair a b) = TPair (Tan a) (Tan b)
+ Tan (TEither a b) = TEither (Tan a) (Tan b)
+ Tan (TLEither a b) = TLEither (Tan a) (Tan b)
+ Tan (TMaybe t) = TMaybe (Tan t)
+ Tan (TArr n t) = TArr n (Tan t)
+ Tan (TScal t) = TanS t
+
+type family TanS t where
+ TanS TI32 = TNil
+ TanS TI64 = TNil
+ TanS TF32 = TScal TF32
+ TanS TF64 = TScal TF64
+ TanS TBool = TNil
+
+type family TanE env where
+ TanE '[] = '[]
+ TanE (t : env) = Tan t : TanE env
+
+tanty :: STy t -> STy (Tan t)
+tanty STNil = STNil
+tanty (STPair a b) = STPair (tanty a) (tanty b)
+tanty (STEither a b) = STEither (tanty a) (tanty b)
+tanty (STLEither a b) = STLEither (tanty a) (tanty b)
+tanty (STMaybe t) = STMaybe (tanty t)
+tanty (STArr n t) = STArr n (tanty t)
+tanty (STScal t) = case t of
+ STI32 -> STNil
+ STI64 -> STNil
+ STF32 -> STScal STF32
+ STF64 -> STScal STF64
+ STBool -> STNil
+tanty STAccum{} = error "Accumulators not allowed in input program"
+
+tanenv :: SList STy env -> SList STy (TanE env)
+tanenv SNil = SNil
+tanenv (t `SCons` env) = tanty t `SCons` tanenv env
+
+zeroTan :: STy t -> Rep t -> Rep (Tan t)
+zeroTan STNil () = ()
+zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)
+zeroTan (STEither a _) (Left x) = Left (zeroTan a x)
+zeroTan (STEither _ b) (Right y) = Right (zeroTan b y)
+zeroTan (STLEither _ _) Nothing = Nothing
+zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
+zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
+zeroTan (STMaybe _) Nothing = Nothing
+zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)
+zeroTan (STArr _ t) x = fmap (zeroTan t) x
+zeroTan (STScal STI32) _ = ()
+zeroTan (STScal STI64) _ = ()
+zeroTan (STScal STF32) _ = 0.0
+zeroTan (STScal STF64) _ = 0.0
+zeroTan (STScal STBool) _ = ()
+zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+
+tanScalars :: STy t -> Rep (Tan t) -> [Double]
+tanScalars STNil () = []
+tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y
+tanScalars (STEither a _) (Left x) = tanScalars a x
+tanScalars (STEither _ b) (Right y) = tanScalars b y
+tanScalars (STLEither _ _) Nothing = []
+tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
+tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
+tanScalars (STMaybe _) Nothing = []
+tanScalars (STMaybe t) (Just x) = tanScalars t x
+tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x
+tanScalars (STScal STI32) _ = []
+tanScalars (STScal STI64) _ = []
+tanScalars (STScal STF32) x = [realToFrac x]
+tanScalars (STScal STF64) x = [x]
+tanScalars (STScal STBool) _ = []
+tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
+
+tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
+tanEScalars SNil SNil = []
+tanEScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ tanEScalars ts xs
+
+unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
+unzipDN STNil _ = ((), ())
+unzipDN (STPair a b) (d1, d2) =
+ let (x, dx) = unzipDN a d1
+ (y, dy) = unzipDN b d2
+ in ((x, y), (dx, dy))
+unzipDN (STEither a b) d = case d of
+ Left d1 -> bimap Left Left (unzipDN a d1)
+ Right d2 -> bimap Right Right (unzipDN b d2)
+unzipDN (STLEither a b) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
+ Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
+unzipDN (STMaybe t) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just d' -> bimap Just Just (unzipDN t d')
+unzipDN (STArr _ t) d =
+ let pairs = arrayMap (unzipDN t) d
+ in (arrayMap fst pairs, arrayMap snd pairs)
+unzipDN (STScal ty) d = case ty of
+ STI32 -> (d, ())
+ STI64 -> (d, ())
+ STF32 -> d
+ STF64 -> d
+ STBool -> (d, ())
+unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
+
+dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
+dotprodTan STNil _ _ = 0.0
+dotprodTan (STPair a b) (x, y) (x', y') =
+ dotprodTan a x x' + dotprodTan b y y'
+dotprodTan (STEither a b) x y = case (x, y) of
+ (Left x', Left y') -> dotprodTan a x' y'
+ (Right x', Right y') -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible Either alternatives"
+dotprodTan (STLEither a b) x y = case (x, y) of
+ (Nothing, _) -> 0.0 -- 0 * y = 0
+ (_, Nothing) -> 0.0 -- x * 0 = 0
+ (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
+ (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible LEither alternatives"
+dotprodTan (STMaybe t) x y = case (x, y) of
+ (Nothing, Nothing) -> 0.0
+ (Just x', Just y') -> dotprodTan t x' y'
+ _ -> error "dotprodTan: incompatible Maybe alternatives"
+dotprodTan (STArr _ t) x y =
+ let sh1 = arrayShape x
+ sh2 = arrayShape y
+ in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0
+ | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1]
+ | otherwise -> error "dotprodTan: incompatible array shapes"
+dotprodTan (STScal ty) x y = case ty of
+ STI32 -> 0.0
+ STI64 -> 0.0
+ STF32 -> realToFrac @Float @Double (x * y)
+ STF64 -> x * y
+ STBool -> 0.0
+dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
+
+-- -- Primal expression must be duplicable
+-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
+-- dnConstE STNil _ = ENil ext
+-- dnConstE (STPair t1 t2) e =
+-- -- This creates fst/snd stacks of unbounded size, but let's not care here
+-- EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e))
+-- dnConstE (STEither t1 t2) e =
+-- ECase ext e
+-- (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ)))
+-- (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ)))
+-- dnConstE (STMaybe t) e =
+-- EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e
+-- dnConstE (STArr n t) e =
+-- EBuild ext n (EShape ext e)
+-- (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ)))
+-- dnConstE (STScal t) e = case t of
+-- STI32 -> e
+-- STI64 -> e
+-- STF32 -> EPair ext e (EConst ext STF32 0.0)
+-- STF64 -> EPair ext e (EConst ext STF64 0.0)
+-- STBool -> e
+-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program"
+
+dnConst :: STy t -> Rep t -> Rep (DN t)
+dnConst STNil = const ()
+dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
+dnConst (STMaybe t) = fmap (dnConst t)
+dnConst (STArr _ t) = arrayMap (dnConst t)
+dnConst (STScal t) = case t of
+ STI32 -> id
+ STI64 -> id
+ STF32 -> (,0.0)
+ STF64 -> (,0.0)
+ STBool -> id
+dnConst STAccum{} = error "Accumulators not allowed in input program"
+
+-- | Given a function that computes the forward derivative for a particular
+-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
+-- @t@ input.
+type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t)
+
+dnOnehots :: STy t -> Rep t -> RevByFwd t
+dnOnehots STNil _ = \_ -> ()
+dnOnehots (STPair t1 t2) (x, y) =
+ \f -> (dnOnehots t1 x (f . (,dnConst t2 y)), dnOnehots t2 y (f . (dnConst t1 x,)))
+dnOnehots (STEither t1 t2) e =
+ case e of
+ Left x -> \f -> Left (dnOnehots t1 x (f . Left))
+ Right y -> \f -> Right (dnOnehots t2 y (f . Right))
+dnOnehots (STLEither t1 t2) e =
+ case e of
+ Nothing -> \_ -> Nothing
+ Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
+ Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
+dnOnehots (STMaybe t) m =
+ case m of
+ Nothing -> \_ -> Nothing
+ Just x -> \f -> Just (dnOnehots t x (f . Just))
+dnOnehots (STArr _ t) a =
+ \f ->
+ arrayGenerate (arrayShape a) $ \idx ->
+ dnOnehots t (arrayIndex a idx) (f . (\oh -> arrayGenerate (arrayShape a) $ \i ->
+ if i == idx then oh else dnConst t (arrayIndex a i)))
+dnOnehots (STScal t) x = case t of
+ STI32 -> \_ -> ()
+ STI64 -> \_ -> ()
+ STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0)
+ STF64 -> \f -> f (x, 1.0)
+ STBool -> \_ -> ()
+dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
+
+dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
+dnConstEnv SNil SNil = SNil
+dnConstEnv (t `SCons` env) (Value x `SCons` val) = Value (dnConst t x) `SCons` dnConstEnv env val
+
+type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env)
+
+dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env
+dnOnehotEnvs SNil SNil = \_ -> SNil
+dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =
+ \f ->
+ Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val)))
+ `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh))
+
+data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t))
+
+makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
+makeFwdADArtifactInterp env expr =
+ let dexpr = dfwdDN expr
+ in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr)
+
+{-# NOINLINE makeFwdADArtifactCompile #-}
+makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String)
+makeFwdADArtifactCompile env expr = do
+ (fun, output) <- compile (dne env) (dfwdDN expr)
+ return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output)
+
+drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr)
+
+drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwd (FwdADArtifact env outty fun) input dres =
+ dnOnehotEnvs env input $ \dnInput ->
+ -- trace (showEnv (dne env) dnInput) $
+ let (_, outtan) = unzipDN outty (fun dnInput)
+ in dotprodTan outty outtan dres