aboutsummaryrefslogtreecommitdiff
path: root/src/ForwardAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r--src/ForwardAD.hs270
1 files changed, 0 insertions, 270 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
deleted file mode 100644
index 6655423..0000000
--- a/src/ForwardAD.hs
+++ /dev/null
@@ -1,270 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE MultiWayIf #-}
-{-# LANGUAGE TupleSections #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-module ForwardAD where
-
-import Data.Bifunctor (bimap)
-import System.IO.Unsafe
-
--- import Debug.Trace
--- import AST.Pretty
-
-import Array
-import AST
-import Compile
-import Data
-import ForwardAD.DualNumbers
-import Interpreter
-import Interpreter.Rep
-
-
--- | Tangent along a type (coincides with cotangent for these types)
-type family Tan t where
- Tan TNil = TNil
- Tan (TPair a b) = TPair (Tan a) (Tan b)
- Tan (TEither a b) = TEither (Tan a) (Tan b)
- Tan (TLEither a b) = TLEither (Tan a) (Tan b)
- Tan (TMaybe t) = TMaybe (Tan t)
- Tan (TArr n t) = TArr n (Tan t)
- Tan (TScal t) = TanS t
-
-type family TanS t where
- TanS TI32 = TNil
- TanS TI64 = TNil
- TanS TF32 = TScal TF32
- TanS TF64 = TScal TF64
- TanS TBool = TNil
-
-type family TanE env where
- TanE '[] = '[]
- TanE (t : env) = Tan t : TanE env
-
-tanty :: STy t -> STy (Tan t)
-tanty STNil = STNil
-tanty (STPair a b) = STPair (tanty a) (tanty b)
-tanty (STEither a b) = STEither (tanty a) (tanty b)
-tanty (STLEither a b) = STLEither (tanty a) (tanty b)
-tanty (STMaybe t) = STMaybe (tanty t)
-tanty (STArr n t) = STArr n (tanty t)
-tanty (STScal t) = case t of
- STI32 -> STNil
- STI64 -> STNil
- STF32 -> STScal STF32
- STF64 -> STScal STF64
- STBool -> STNil
-tanty STAccum{} = error "Accumulators not allowed in input program"
-
-tanenv :: SList STy env -> SList STy (TanE env)
-tanenv SNil = SNil
-tanenv (t `SCons` env) = tanty t `SCons` tanenv env
-
-zeroTan :: STy t -> Rep t -> Rep (Tan t)
-zeroTan STNil () = ()
-zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)
-zeroTan (STEither a _) (Left x) = Left (zeroTan a x)
-zeroTan (STEither _ b) (Right y) = Right (zeroTan b y)
-zeroTan (STLEither _ _) Nothing = Nothing
-zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
-zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
-zeroTan (STMaybe _) Nothing = Nothing
-zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)
-zeroTan (STArr _ t) x = fmap (zeroTan t) x
-zeroTan (STScal STI32) _ = ()
-zeroTan (STScal STI64) _ = ()
-zeroTan (STScal STF32) _ = 0.0
-zeroTan (STScal STF64) _ = 0.0
-zeroTan (STScal STBool) _ = ()
-zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
-
-tanScalars :: STy t -> Rep (Tan t) -> [Double]
-tanScalars STNil () = []
-tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y
-tanScalars (STEither a _) (Left x) = tanScalars a x
-tanScalars (STEither _ b) (Right y) = tanScalars b y
-tanScalars (STLEither _ _) Nothing = []
-tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
-tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
-tanScalars (STMaybe _) Nothing = []
-tanScalars (STMaybe t) (Just x) = tanScalars t x
-tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x
-tanScalars (STScal STI32) _ = []
-tanScalars (STScal STI64) _ = []
-tanScalars (STScal STF32) x = [realToFrac x]
-tanScalars (STScal STF64) x = [x]
-tanScalars (STScal STBool) _ = []
-tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
-
-tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
-tanEScalars SNil SNil = []
-tanEScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ tanEScalars ts xs
-
-unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
-unzipDN STNil _ = ((), ())
-unzipDN (STPair a b) (d1, d2) =
- let (x, dx) = unzipDN a d1
- (y, dy) = unzipDN b d2
- in ((x, y), (dx, dy))
-unzipDN (STEither a b) d = case d of
- Left d1 -> bimap Left Left (unzipDN a d1)
- Right d2 -> bimap Right Right (unzipDN b d2)
-unzipDN (STLEither a b) d = case d of
- Nothing -> (Nothing, Nothing)
- Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
- Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
-unzipDN (STMaybe t) d = case d of
- Nothing -> (Nothing, Nothing)
- Just d' -> bimap Just Just (unzipDN t d')
-unzipDN (STArr _ t) d =
- let pairs = arrayMap (unzipDN t) d
- in (arrayMap fst pairs, arrayMap snd pairs)
-unzipDN (STScal ty) d = case ty of
- STI32 -> (d, ())
- STI64 -> (d, ())
- STF32 -> d
- STF64 -> d
- STBool -> (d, ())
-unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
-
-dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
-dotprodTan STNil _ _ = 0.0
-dotprodTan (STPair a b) (x, y) (x', y') =
- dotprodTan a x x' + dotprodTan b y y'
-dotprodTan (STEither a b) x y = case (x, y) of
- (Left x', Left y') -> dotprodTan a x' y'
- (Right x', Right y') -> dotprodTan b x' y'
- _ -> error "dotprodTan: incompatible Either alternatives"
-dotprodTan (STLEither a b) x y = case (x, y) of
- (Nothing, _) -> 0.0 -- 0 * y = 0
- (_, Nothing) -> 0.0 -- x * 0 = 0
- (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
- (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
- _ -> error "dotprodTan: incompatible LEither alternatives"
-dotprodTan (STMaybe t) x y = case (x, y) of
- (Nothing, Nothing) -> 0.0
- (Just x', Just y') -> dotprodTan t x' y'
- _ -> error "dotprodTan: incompatible Maybe alternatives"
-dotprodTan (STArr _ t) x y =
- let sh1 = arrayShape x
- sh2 = arrayShape y
- in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0
- | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1]
- | otherwise -> error "dotprodTan: incompatible array shapes"
-dotprodTan (STScal ty) x y = case ty of
- STI32 -> 0.0
- STI64 -> 0.0
- STF32 -> realToFrac @Float @Double (x * y)
- STF64 -> x * y
- STBool -> 0.0
-dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
-
--- -- Primal expression must be duplicable
--- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
--- dnConstE STNil _ = ENil ext
--- dnConstE (STPair t1 t2) e =
--- -- This creates fst/snd stacks of unbounded size, but let's not care here
--- EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e))
--- dnConstE (STEither t1 t2) e =
--- ECase ext e
--- (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ)))
--- (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ)))
--- dnConstE (STMaybe t) e =
--- EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e
--- dnConstE (STArr n t) e =
--- EBuild ext n (EShape ext e)
--- (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ)))
--- dnConstE (STScal t) e = case t of
--- STI32 -> e
--- STI64 -> e
--- STF32 -> EPair ext e (EConst ext STF32 0.0)
--- STF64 -> EPair ext e (EConst ext STF64 0.0)
--- STBool -> e
--- dnConstE STAccum{} _ = error "Accumulators not allowed in input program"
-
-dnConst :: STy t -> Rep t -> Rep (DN t)
-dnConst STNil = const ()
-dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)
-dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2)
-dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
-dnConst (STMaybe t) = fmap (dnConst t)
-dnConst (STArr _ t) = arrayMap (dnConst t)
-dnConst (STScal t) = case t of
- STI32 -> id
- STI64 -> id
- STF32 -> (,0.0)
- STF64 -> (,0.0)
- STBool -> id
-dnConst STAccum{} = error "Accumulators not allowed in input program"
-
--- | Given a function that computes the forward derivative for a particular
--- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
--- @t@ input.
-type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t)
-
-dnOnehots :: STy t -> Rep t -> RevByFwd t
-dnOnehots STNil _ = \_ -> ()
-dnOnehots (STPair t1 t2) (x, y) =
- \f -> (dnOnehots t1 x (f . (,dnConst t2 y)), dnOnehots t2 y (f . (dnConst t1 x,)))
-dnOnehots (STEither t1 t2) e =
- case e of
- Left x -> \f -> Left (dnOnehots t1 x (f . Left))
- Right y -> \f -> Right (dnOnehots t2 y (f . Right))
-dnOnehots (STLEither t1 t2) e =
- case e of
- Nothing -> \_ -> Nothing
- Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
- Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
-dnOnehots (STMaybe t) m =
- case m of
- Nothing -> \_ -> Nothing
- Just x -> \f -> Just (dnOnehots t x (f . Just))
-dnOnehots (STArr _ t) a =
- \f ->
- arrayGenerate (arrayShape a) $ \idx ->
- dnOnehots t (arrayIndex a idx) (f . (\oh -> arrayGenerate (arrayShape a) $ \i ->
- if i == idx then oh else dnConst t (arrayIndex a i)))
-dnOnehots (STScal t) x = case t of
- STI32 -> \_ -> ()
- STI64 -> \_ -> ()
- STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0)
- STF64 -> \f -> f (x, 1.0)
- STBool -> \_ -> ()
-dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
-
-dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
-dnConstEnv SNil SNil = SNil
-dnConstEnv (t `SCons` env) (Value x `SCons` val) = Value (dnConst t x) `SCons` dnConstEnv env val
-
-type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env)
-
-dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env
-dnOnehotEnvs SNil SNil = \_ -> SNil
-dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =
- \f ->
- Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val)))
- `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh))
-
-data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t))
-
-makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
-makeFwdADArtifactInterp env expr =
- let dexpr = dfwdDN expr
- in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr)
-
-{-# NOINLINE makeFwdADArtifactCompile #-}
-makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String)
-makeFwdADArtifactCompile env expr = do
- (fun, output) <- compile (dne env) (dfwdDN expr)
- return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output)
-
-drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
-drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr)
-
-drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
-drevByFwd (FwdADArtifact env outty fun) input dres =
- dnOnehotEnvs env input $ \dnInput ->
- -- trace (showEnv (dne env) dnInput) $
- let (_, outtan) = unzipDN outty (fun dnInput)
- in dotprodTan outty outtan dres