diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/ForwardAD.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/ForwardAD.hs')
| -rw-r--r-- | src/CHAD/ForwardAD.hs | 270 |
1 files changed, 270 insertions, 0 deletions
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Forward-mode differentiation through dual numbers, plus a reverse
-- derivative ('drevByFwd') that is assembled by running the forward derivative
-- once for every floating-point scalar in the input.
module CHAD.ForwardAD where

import Data.Bifunctor (bimap)
import System.IO.Unsafe (unsafePerformIO)

-- import Debug.Trace
-- import CHAD.AST.Pretty

import CHAD.Array
import CHAD.AST
import CHAD.Compile
import CHAD.Data
import CHAD.ForwardAD.DualNumbers
import CHAD.Interpreter
import CHAD.Interpreter.Rep


-- | Tangent along a type (coincides with cotangent for these types).
--
-- The tangent type mirrors the structure of the primal type; only the scalar
-- leaves change, as described by 'TanS'.
type family Tan t where
  Tan TNil = TNil
  Tan (TPair a b) = TPair (Tan a) (Tan b)
  Tan (TEither a b) = TEither (Tan a) (Tan b)
  Tan (TLEither a b) = TLEither (Tan a) (Tan b)
  Tan (TMaybe t) = TMaybe (Tan t)
  Tan (TArr n t) = TArr n (Tan t)
  Tan (TScal t) = TanS t

-- | Tangent of a scalar type: floating-point scalars are their own tangent,
-- integers and booleans carry no tangent information ('TNil').
type family TanS t where
  TanS TI32 = TNil
  TanS TI64 = TNil
  TanS TF32 = TScal TF32
  TanS TF64 = TScal TF64
  TanS TBool = TNil

-- | 'Tan' applied to every type of an environment.
type family TanE env where
  TanE '[] = '[]
  TanE (t : env) = Tan t : TanE env

-- | Singleton-level counterpart of 'Tan'.
--
-- Calls 'error' when it encounters an accumulator type.
tanty :: STy t -> STy (Tan t)
tanty ty = case ty of
  STNil -> STNil
  STPair a b -> STPair (tanty a) (tanty b)
  STEither a b -> STEither (tanty a) (tanty b)
  STLEither a b -> STLEither (tanty a) (tanty b)
  STMaybe t -> STMaybe (tanty t)
  STArr n t -> STArr n (tanty t)
  STScal STI32 -> STNil
  STScal STI64 -> STNil
  STScal STF32 -> STScal STF32
  STScal STF64 -> STScal STF64
  STScal STBool -> STNil
  STAccum{} -> error "Accumulators not allowed in input program"

-- | Singleton-level counterpart of 'TanE': 'tanty' on every entry.
tanenv :: SList STy env -> SList STy (TanE env)
tanenv SNil = SNil
tanenv (SCons t ts) = SCons (tanty t) (tanenv ts)

-- | The zero tangent with the same shape (same sum alternatives, same array
-- contents layout) as the given primal value.
--
-- Calls 'error' when it encounters an accumulator type.
zeroTan :: STy t -> Rep t -> Rep (Tan t)
zeroTan STNil () = ()
zeroTan (STPair ta tb) (p, q) = (zeroTan ta p, zeroTan tb q)
zeroTan (STEither ta tb) e = bimap (zeroTan ta) (zeroTan tb) e
zeroTan (STLEither ta tb) me = fmap (bimap (zeroTan ta) (zeroTan tb)) me
zeroTan (STMaybe ty) m = fmap (zeroTan ty) m
zeroTan (STArr _ ty) arr = fmap (zeroTan ty) arr
zeroTan (STScal sty) _ = case sty of
  STI32 -> ()
  STI64 -> ()
  STF32 -> 0.0
  STF64 -> 0.0
  STBool -> ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"

-- | All floating-point scalars contained in a tangent value, in traversal
-- order, widened to 'Double'. Non-float leaves contribute nothing.
--
-- Calls 'error' when it encounters an accumulator type.
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
tanScalars (STPair ta tb) (p, q) = tanScalars ta p ++ tanScalars tb q
tanScalars (STEither ta tb) e = either (tanScalars ta) (tanScalars tb) e
tanScalars (STLEither ta tb) me = maybe [] (either (tanScalars ta) (tanScalars tb)) me
tanScalars (STMaybe ty) m = maybe [] (tanScalars ty) m
tanScalars (STArr _ ty) arr = foldMap id (arrayMap (tanScalars ty) arr)
tanScalars (STScal sty) v = case sty of
  STI32 -> []
  STI64 -> []
  STF32 -> [realToFrac v]
  STF64 -> [v]
  STBool -> []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"

-- | 'tanScalars' over a whole environment of tangent values, concatenated in
-- environment order.
tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
tanEScalars SNil SNil = []
tanEScalars (SCons t ts) (SCons (Value v) vs) = tanScalars t v ++ tanEScalars ts vs

-- | Split a dual-numbers value into its primal part and its tangent part.
--
-- For floating-point scalars the dual number already is the
-- (primal, tangent) pair; integers and booleans get a unit tangent.
--
-- Calls 'error' when it encounters an accumulator type.
unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
unzipDN STNil _ = ((), ())
unzipDN (STPair ta tb) (d1, d2) = ((p, q), (dp, dq))
  where
    (p, dp) = unzipDN ta d1
    (q, dq) = unzipDN tb d2
unzipDN (STEither ta _) (Left d) = bimap Left Left (unzipDN ta d)
unzipDN (STEither _ tb) (Right d) = bimap Right Right (unzipDN tb d)
unzipDN (STLEither _ _) Nothing = (Nothing, Nothing)
unzipDN (STLEither ta _) (Just (Left d)) = bimap (Just . Left) (Just . Left) (unzipDN ta d)
unzipDN (STLEither _ tb) (Just (Right d)) = bimap (Just . Right) (Just . Right) (unzipDN tb d)
unzipDN (STMaybe _) Nothing = (Nothing, Nothing)
unzipDN (STMaybe ty) (Just d) = bimap Just Just (unzipDN ty d)
unzipDN (STArr _ ty) d = (arrayMap fst zipped, arrayMap snd zipped)
  where
    zipped = arrayMap (unzipDN ty) d
unzipDN (STScal STI32) d = (d, ())
unzipDN (STScal STI64) d = (d, ())
unzipDN (STScal STF32) d = d
unzipDN (STScal STF64) d = d
unzipDN (STScal STBool) d = (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"

-- | Inner product of two tangent values of the same type, as a 'Double'.
--
-- Non-float leaves contribute zero. A 'Nothing' on either side of an
-- @LEither@ counts as zero, and so does an array of size zero on either side.
-- Any other structural mismatch (different sum alternatives, different array
-- shapes) is reported with 'error', as are accumulator types.
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
dotprodTan (STPair ta tb) (p, q) (p', q') =
  dotprodTan ta p p' + dotprodTan tb q q'
dotprodTan (STEither ta _) (Left l) (Left l') = dotprodTan ta l l'
dotprodTan (STEither _ tb) (Right r) (Right r') = dotprodTan tb r r'
dotprodTan (STEither _ _) _ _ = error "dotprodTan: incompatible Either alternatives"
dotprodTan (STLEither _ _) Nothing _ = 0.0 -- 0 * y = 0
dotprodTan (STLEither _ _) _ Nothing = 0.0 -- x * 0 = 0
dotprodTan (STLEither ta _) (Just (Left l)) (Just (Left l')) = dotprodTan ta l l'
dotprodTan (STLEither _ tb) (Just (Right r)) (Just (Right r')) = dotprodTan tb r r'
dotprodTan (STLEither _ _) _ _ = error "dotprodTan: incompatible LEither alternatives"
dotprodTan (STMaybe _) Nothing Nothing = 0.0
dotprodTan (STMaybe ty) (Just v) (Just v') = dotprodTan ty v v'
dotprodTan (STMaybe _) _ _ = error "dotprodTan: incompatible Maybe alternatives"
dotprodTan (STArr _ ty) x y
  | shapeSize shX == 0 || shapeSize shY == 0 = 0.0
  | shX == shY = sum [dotprodTan ty (arrayIndex x i) (arrayIndex y i) | i <- enumShape shX]
  | otherwise = error "dotprodTan: incompatible array shapes"
  where
    shX = arrayShape x
    shY = arrayShape y
dotprodTan (STScal STI32) _ _ = 0.0
dotprodTan (STScal STI64) _ _ = 0.0
-- The product is taken in 'Float' and only then widened to 'Double'.
dotprodTan (STScal STF32) x y = realToFrac @Float @Double (x * y)
dotprodTan (STScal STF64) x y = x * y
dotprodTan (STScal STBool) _ _ = 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"

-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
-- dnConstE STNil _ = ENil ext
-- dnConstE (STPair t1 t2) e =
--   -- This creates fst/snd stacks of unbounded size, but let's not care here
--   EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e))
-- dnConstE (STEither t1 t2) e =
--   ECase ext e
--     (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ)))
--     (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ)))
-- dnConstE (STMaybe t) e =
--   EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e
-- dnConstE (STArr n t) e =
--   EBuild ext n (EShape ext e)
--     (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ)))
-- dnConstE (STScal t) e = case t of
--   STI32 -> e
--   STI64 -> e
--   STF32 -> EPair ext e (EConst ext STF32 0.0)
--   STF64 -> EPair ext e (EConst ext STF64 0.0)
--   STBool -> e
-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program"

-- | Embed a primal value into dual numbers as a constant: every
-- floating-point scalar is paired with a zero tangent, everything else is
-- kept as it is.
--
-- Calls 'error' when it encounters an accumulator type.
dnConst :: STy t -> Rep t -> Rep (DN t)
dnConst STNil = const ()
dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)
dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2)
dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
dnConst (STMaybe ty) = fmap (dnConst ty)
dnConst (STArr _ ty) = arrayMap (dnConst ty)
dnConst (STScal STI32) = id
dnConst (STScal STI64) = id
dnConst (STScal STF32) = (,0.0)
dnConst (STScal STF64) = (,0.0)
dnConst (STScal STBool) = id
dnConst STAccum{} = error "Accumulators not allowed in input program"

-- | Given a function that computes the forward derivative for a particular
-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
-- @t@ input.
type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t)

-- | Build the 'RevByFwd' for a primal value.
--
-- For every floating-point scalar in the value, the supplied function is
-- applied to a dual-numbers copy of the value in which that one scalar has
-- tangent @1.0@ and all others have tangent zero (via 'dnConst'); the result
-- is stored at the corresponding position of the returned tangent.
--
-- In the array case a complete dual-numbers array is generated for each
-- element of the array.
--
-- Calls 'error' when it encounters an accumulator type.
dnOnehots :: STy t -> Rep t -> RevByFwd t
dnOnehots STNil _ = \_ -> ()
dnOnehots (STPair t1 t2) (x, y) =
  \f ->
    ( dnOnehots t1 x (f . (,dnConst t2 y))
    , dnOnehots t2 y (f . (dnConst t1 x,))
    )
dnOnehots (STEither t1 _) (Left x) = \f -> Left (dnOnehots t1 x (f . Left))
dnOnehots (STEither _ t2) (Right y) = \f -> Right (dnOnehots t2 y (f . Right))
dnOnehots (STLEither _ _) Nothing = \_ -> Nothing
dnOnehots (STLEither t1 _) (Just (Left x)) =
  \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
dnOnehots (STLEither _ t2) (Just (Right y)) =
  \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
dnOnehots (STMaybe _) Nothing = \_ -> Nothing
dnOnehots (STMaybe ty) (Just x) = \f -> Just (dnOnehots ty x (f . Just))
dnOnehots (STArr _ ty) arr =
  \f ->
    arrayGenerate (arrayShape arr) $ \hot ->
      -- Differentiate with respect to element @hot@ only: every other element
      -- of the array passed to @f@ is a constant.
      dnOnehots ty (arrayIndex arr hot) $ \oh ->
        f (arrayGenerate (arrayShape arr) $ \i ->
             if i == hot then oh else dnConst ty (arrayIndex arr i))
dnOnehots (STScal STI32) _ = \_ -> ()
dnOnehots (STScal STI64) _ = \_ -> ()
dnOnehots (STScal STF32) x = \f -> realToFrac @Double @Float $ f (x, 1.0)
dnOnehots (STScal STF64) x = \f -> f (x, 1.0)
dnOnehots (STScal STBool) _ = \_ -> ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"

-- | 'dnConst' applied to every value of an environment.
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil
dnConstEnv (SCons t ts) (SCons (Value v) vs) =
  SCons (Value (dnConst t v)) (dnConstEnv ts vs)

-- | Environment-level analogue of 'RevByFwd'.
type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env)

-- | Environment-level analogue of 'dnOnehots': while one entry of the
-- environment is being differentiated, all other entries are passed as
-- constants.
dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env
dnOnehotEnvs SNil SNil = \_ -> SNil
dnOnehotEnvs (SCons t ts) (SCons (Value v) vs) =
  \f ->
    let -- Vary the head entry; the tail of the environment is constant.
        headTan = dnOnehots t v (\oh -> f (SCons (Value oh) (dnConstEnv ts vs)))
        -- Vary the tail; the head entry is constant.
        tailTans = dnOnehotEnvs ts vs (\ohs -> f (SCons (Value (dnConst t v)) ohs))
    in SCons (Value headTan) tailTans

-- | A forward derivative that is ready to run: the environment types, the
-- result type, and a function from a dual-numbers environment to the
-- dual-numbers result.
data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t))

-- | Build a 'FwdADArtifact' that evaluates the dual-numbers program
-- ('dfwdDN') with 'interpretOpen'.
makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
makeFwdADArtifactInterp env expr =
  FwdADArtifact env (typeOf expr) run
  where
    -- Bound outside 'run' so the transformed program is shared between runs.
    dexpr = dfwdDN expr
    run inp = interpretOpen False (dne env) inp dexpr

-- | Build a 'FwdADArtifact' by passing the dual-numbers program ('dfwdDN') to
-- 'compile'. The 'String' that 'compile' produces is returned alongside.
--
-- NOTE(review): the compiled function runs in 'IO' and is exposed as a pure
-- function through 'unsafePerformIO'; this presumes the compiled code behaves
-- like a pure function -- confirm against "CHAD.Compile".
{-# NOINLINE makeFwdADArtifactCompile #-}
makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String)
makeFwdADArtifactCompile env expr = do
  (compiled, output) <- compile (dne env) (dfwdDN expr)
  pure (FwdADArtifact env (typeOf expr) (\inp -> unsafePerformIO (compiled inp)), output)

-- | 'drevByFwd' on an interpreter-backed artifact for the given expression.
drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr)

-- | Reverse derivative computed with forward passes only.
--
-- For each one-hot dual-numbers input produced by 'dnOnehotEnvs', the forward
-- derivative is run and the tangent of its result is contracted with the
-- incoming result tangent using 'dotprodTan'.
drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
drevByFwd (FwdADArtifact env outty fun) input dres =
  dnOnehotEnvs env input directional
  where
    directional dnInput =
      -- trace (showEnv (dne env) dnInput) $
      dotprodTan outty (snd (unzipDN outty (fun dnInput))) dres
