diff options
Diffstat (limited to 'src/ForwardAD.hs')
| -rw-r--r-- | src/ForwardAD.hs | 28 | 
1 files changed, 28 insertions, 0 deletions
| diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index b7036dd..5756f96 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -29,6 +29,7 @@ type family Tan t where    Tan (TMaybe t) = TMaybe (Tan t)    Tan (TArr n t) = TArr n (Tan t)    Tan (TScal t) = TanS t +  Tan (TLEither a b) = TLEither (Tan a) (Tan b)  type family TanS t where    TanS TI32 = TNil @@ -54,6 +55,11 @@ tanty (STScal t) = case t of    STF64 -> STScal STF64    STBool -> STNil  tanty STAccum{} = error "Accumulators not allowed in input program" +tanty (STLEither a b) = STLEither (tanty a) (tanty b) + +tanenv :: SList STy env -> SList STy (TanE env) +tanenv SNil = SNil +tanenv (t `SCons` env) = tanty t `SCons` tanenv env  zeroTan :: STy t -> Rep t -> Rep (Tan t)  zeroTan STNil () = () @@ -69,6 +75,9 @@ zeroTan (STScal STF32) _ = 0.0  zeroTan (STScal STF64) _ = 0.0  zeroTan (STScal STBool) _ = ()  zeroTan STAccum{} _ = error "Accumulators not allowed in input program" +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))  tanScalars :: STy t -> Rep (Tan t) -> [Double]  tanScalars STNil () = [] @@ -84,6 +93,9 @@ tanScalars (STScal STF32) x = [realToFrac x]  tanScalars (STScal STF64) x = [x]  tanScalars (STScal STBool) _ = []  tanScalars STAccum{} _ = error "Accumulators not allowed in input program" +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y  tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]  tanEScalars SNil SNil = [] @@ -111,6 +123,10 @@ unzipDN (STScal ty) d = case ty of    STF64 -> d    STBool -> (d, ())  unzipDN STAccum{} _ = error "Accumulators not allowed in input program" +unzipDN (STLEither a b) d = case d of +  Nothing -> (Nothing, Nothing) +  Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) +  Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)  dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double  dotprodTan STNil _ _ = 0.0 @@ -137,6 +153,12 @@ dotprodTan (STScal ty) x y = case ty of    STF64 -> x * y    STBool -> 0.0  dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" +dotprodTan (STLEither a b) x y = case (x, y) of +  (Nothing, _) -> 0.0  -- 0 * y = 0 +  (_, Nothing) -> 0.0  -- x * 0 = 0 +  (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' +  (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' +  _ -> error "dotprodTan: incompatible LEither alternatives"  -- -- Primal expression must be duplicable  -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -174,6 +196,7 @@ dnConst (STScal t) = case t of    STF64 -> (,0.0)    STBool -> id  dnConst STAccum{} = error "Accumulators not allowed in input program" +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))  -- | Given a function that computes the forward derivative for a particular  -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -204,6 +227,11 @@ dnOnehots (STScal t) x = case t of    STF64 -> \f -> f (x, 1.0)    STBool -> \_ -> ()  dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" +dnOnehots (STLEither t1 t2) e = +  case e of +    Nothing -> \_ -> Nothing +    Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) +    Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))  dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)  dnConstEnv SNil SNil = SNil | 
