diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/ForwardAD.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/ForwardAD.hs')
| -rw-r--r-- | src/ForwardAD.hs | 270 |
1 files changed, 0 insertions, 270 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs deleted file mode 100644 index 6655423..0000000 --- a/src/ForwardAD.hs +++ /dev/null @@ -1,270 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module ForwardAD where - -import Data.Bifunctor (bimap) -import System.IO.Unsafe - --- import Debug.Trace --- import AST.Pretty - -import Array -import AST -import Compile -import Data -import ForwardAD.DualNumbers -import Interpreter -import Interpreter.Rep - - --- | Tangent along a type (coincides with cotangent for these types) -type family Tan t where - Tan TNil = TNil - Tan (TPair a b) = TPair (Tan a) (Tan b) - Tan (TEither a b) = TEither (Tan a) (Tan b) - Tan (TLEither a b) = TLEither (Tan a) (Tan b) - Tan (TMaybe t) = TMaybe (Tan t) - Tan (TArr n t) = TArr n (Tan t) - Tan (TScal t) = TanS t - -type family TanS t where - TanS TI32 = TNil - TanS TI64 = TNil - TanS TF32 = TScal TF32 - TanS TF64 = TScal TF64 - TanS TBool = TNil - -type family TanE env where - TanE '[] = '[] - TanE (t : env) = Tan t : TanE env - -tanty :: STy t -> STy (Tan t) -tanty STNil = STNil -tanty (STPair a b) = STPair (tanty a) (tanty b) -tanty (STEither a b) = STEither (tanty a) (tanty b) -tanty (STLEither a b) = STLEither (tanty a) (tanty b) -tanty (STMaybe t) = STMaybe (tanty t) -tanty (STArr n t) = STArr n (tanty t) -tanty (STScal t) = case t of - STI32 -> STNil - STI64 -> STNil - STF32 -> STScal STF32 - STF64 -> STScal STF64 - STBool -> STNil -tanty STAccum{} = error "Accumulators not allowed in input program" - -tanenv :: SList STy env -> SList STy (TanE env) -tanenv SNil = SNil -tanenv (t `SCons` env) = tanty t `SCons` tanenv env - -zeroTan :: STy t -> Rep t -> Rep (Tan t) -zeroTan STNil () = () -zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y) -zeroTan (STEither a _) (Left x) = Left (zeroTan a x) -zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) -zeroTan (STLEither _ _) Nothing = Nothing -zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) -zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) -zeroTan (STMaybe _) Nothing = Nothing -zeroTan (STMaybe t) (Just x) = Just (zeroTan t x) -zeroTan (STArr _ t) x = fmap (zeroTan t) x -zeroTan (STScal STI32) _ = () -zeroTan (STScal STI64) _ = () -zeroTan (STScal STF32) _ = 0.0 -zeroTan (STScal STF64) _ = 0.0 -zeroTan (STScal STBool) _ = () -zeroTan STAccum{} _ = error "Accumulators not allowed in input program" - -tanScalars :: STy t -> Rep (Tan t) -> [Double] -tanScalars STNil () = [] -tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y -tanScalars (STEither a _) (Left x) = tanScalars a x -tanScalars (STEither _ b) (Right y) = tanScalars b y -tanScalars (STLEither _ _) Nothing = [] -tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x -tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y -tanScalars (STMaybe _) Nothing = [] -tanScalars (STMaybe t) (Just x) = tanScalars t x -tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x -tanScalars (STScal STI32) _ = [] -tanScalars (STScal STI64) _ = [] -tanScalars (STScal STF32) x = [realToFrac x] -tanScalars (STScal STF64) x = [x] -tanScalars (STScal STBool) _ = [] -tanScalars STAccum{} _ = error "Accumulators not allowed in input program" - -tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] -tanEScalars SNil SNil = [] -tanEScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ tanEScalars ts xs - -unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t)) -unzipDN STNil _ = ((), ()) -unzipDN (STPair a b) (d1, d2) = - let (x, dx) = unzipDN a d1 - (y, dy) = unzipDN b d2 - in ((x, y), (dx, dy)) -unzipDN (STEither a b) d = case d of - Left d1 -> bimap Left Left (unzipDN a d1) - Right d2 -> bimap Right Right (unzipDN b d2) -unzipDN (STLEither a b) d = case d of - Nothing -> (Nothing, Nothing) - Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) - Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) -unzipDN (STMaybe t) d = case d of - Nothing -> (Nothing, Nothing) - Just d' -> bimap Just Just (unzipDN t d') -unzipDN (STArr _ t) d = - let pairs = arrayMap (unzipDN t) d - in (arrayMap fst pairs, arrayMap snd pairs) -unzipDN (STScal ty) d = case ty of - STI32 -> (d, ()) - STI64 -> (d, ()) - STF32 -> d - STF64 -> d - STBool -> (d, ()) -unzipDN STAccum{} _ = error "Accumulators not allowed in input program" - -dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double -dotprodTan STNil _ _ = 0.0 -dotprodTan (STPair a b) (x, y) (x', y') = - dotprodTan a x x' + dotprodTan b y y' -dotprodTan (STEither a b) x y = case (x, y) of - (Left x', Left y') -> dotprodTan a x' y' - (Right x', Right y') -> dotprodTan b x' y' - _ -> error "dotprodTan: incompatible Either alternatives" -dotprodTan (STLEither a b) x y = case (x, y) of - (Nothing, _) -> 0.0 -- 0 * y = 0 - (_, Nothing) -> 0.0 -- x * 0 = 0 - (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' - (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' - _ -> error "dotprodTan: incompatible LEither alternatives" -dotprodTan (STMaybe t) x y = case (x, y) of - (Nothing, Nothing) -> 0.0 - (Just x', Just y') -> dotprodTan t x' y' - _ -> error "dotprodTan: incompatible Maybe alternatives" -dotprodTan (STArr _ t) x y = - let sh1 = arrayShape x - sh2 = arrayShape y - in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0 - | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1] - | otherwise -> error "dotprodTan: incompatible array shapes" -dotprodTan (STScal ty) x y = case ty of - STI32 -> 0.0 - STI64 -> 0.0 - STF32 -> realToFrac @Float @Double (x * y) - STF64 -> x * y - STBool -> 0.0 -dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" - --- -- Primal expression must be duplicable --- dnConstE :: STy t -> Ex env t -> Ex env (DN t) --- dnConstE STNil _ = ENil ext --- dnConstE (STPair t1 t2) e = --- -- This creates fst/snd stacks of unbounded size, but let's not care here --- EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e)) --- dnConstE (STEither t1 t2) e = --- ECase ext e --- (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ))) --- (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ))) --- dnConstE (STMaybe t) e = --- EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e --- dnConstE (STArr n t) e = --- EBuild ext n (EShape ext e) --- (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ))) --- dnConstE (STScal t) e = case t of --- STI32 -> e --- STI64 -> e --- STF32 -> EPair ext e (EConst ext STF32 0.0) --- STF64 -> EPair ext e (EConst ext STF64 0.0) --- STBool -> e --- dnConstE STAccum{} _ = error "Accumulators not allowed in input program" - -dnConst :: STy t -> Rep t -> Rep (DN t) -dnConst STNil = const () -dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2) -dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) -dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) -dnConst (STMaybe t) = fmap (dnConst t) -dnConst (STArr _ t) = arrayMap (dnConst t) -dnConst (STScal t) = case t of - STI32 -> id - STI64 -> id - STF32 -> (,0.0) - STF64 -> (,0.0) - STBool -> id -dnConst STAccum{} = error "Accumulators not allowed in input program" - --- | Given a function that computes the forward derivative for a particular --- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this --- @t@ input. -type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t) - -dnOnehots :: STy t -> Rep t -> RevByFwd t -dnOnehots STNil _ = \_ -> () -dnOnehots (STPair t1 t2) (x, y) = - \f -> (dnOnehots t1 x (f . (,dnConst t2 y)), dnOnehots t2 y (f . (dnConst t1 x,))) -dnOnehots (STEither t1 t2) e = - case e of - Left x -> \f -> Left (dnOnehots t1 x (f . Left)) - Right y -> \f -> Right (dnOnehots t2 y (f . Right)) -dnOnehots (STLEither t1 t2) e = - case e of - Nothing -> \_ -> Nothing - Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) - Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) -dnOnehots (STMaybe t) m = - case m of - Nothing -> \_ -> Nothing - Just x -> \f -> Just (dnOnehots t x (f . Just)) -dnOnehots (STArr _ t) a = - \f -> - arrayGenerate (arrayShape a) $ \idx -> - dnOnehots t (arrayIndex a idx) (f . (\oh -> arrayGenerate (arrayShape a) $ \i -> - if i == idx then oh else dnConst t (arrayIndex a i))) -dnOnehots (STScal t) x = case t of - STI32 -> \_ -> () - STI64 -> \_ -> () - STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0) - STF64 -> \f -> f (x, 1.0) - STBool -> \_ -> () -dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" - -dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) -dnConstEnv SNil SNil = SNil -dnConstEnv (t `SCons` env) (Value x `SCons` val) = Value (dnConst t x) `SCons` dnConstEnv env val - -type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env) - -dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env -dnOnehotEnvs SNil SNil = \_ -> SNil -dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = - \f -> - Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val))) - `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh)) - -data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t)) - -makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t -makeFwdADArtifactInterp env expr = - let dexpr = dfwdDN expr - in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr) - -{-# NOINLINE makeFwdADArtifactCompile #-} -makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String) -makeFwdADArtifactCompile env expr = do - (fun, output) <- compile (dne env) (dfwdDN expr) - return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output) - -drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) -drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr) - -drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) -drevByFwd (FwdADArtifact env outty fun) input dres = - dnOnehotEnvs env input $ \dnInput -> - -- trace (showEnv (dne env) dnInput) $ - let (_, outtan) = unzipDN outty (fun dnInput) - in dotprodTan outty outtan dres |
