diff options
Diffstat (limited to 'src/CHAD')
| -rw-r--r-- | src/CHAD/Top.hs | 2 | ||||
| -rw-r--r-- | src/CHAD/Types.hs | 8 | ||||
| -rw-r--r-- | src/CHAD/Types/ToTan.hs | 10 | 
3 files changed, 10 insertions, 10 deletions
| diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 9e7e7f5..261ddfe 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -49,11 +49,11 @@ d1Identity = \case    STNil -> Refl    STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl    STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl +  STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl    STMaybe t | Refl <- d1Identity t -> Refl    STArr _ t | Refl <- d1Identity t -> Refl    STScal _ -> Refl    STAccum{} -> error "Accumulators not allowed in input program" -  STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl  d1eIdentity :: SList STy env -> D1E env :~: env  d1eIdentity SNil = Refl diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 74e7dbd..974669d 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -11,19 +11,19 @@ type family D1 t where    D1 TNil = TNil    D1 (TPair a b) = TPair (D1 a) (D1 b)    D1 (TEither a b) = TEither (D1 a) (D1 b) +  D1 (TLEither a b) = TLEither (D1 a) (D1 b)    D1 (TMaybe a) = TMaybe (D1 a)    D1 (TArr n t) = TArr n (D1 t)    D1 (TScal t) = TScal t -  D1 (TLEither a b) = TLEither (D1 a) (D1 b)  type family D2 t where    D2 TNil = TNil    D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))    D2 (TEither a b) = TLEither (D2 a) (D2 b) +  D2 (TLEither a b) = TLEither (D2 a) (D2 b)    D2 (TMaybe t) = TMaybe (D2 t)    D2 (TArr n t) = TMaybe (TArr n (D2 t))    D2 (TScal t) = D2s t -  D2 (TLEither a b) = TLEither (D2 a) (D2 b)  type family D2s t where    D2s TI32 = TNil @@ -48,11 +48,11 @@ d1 :: STy t -> STy (D1 t)  d1 STNil = STNil  d1 (STPair a b) = STPair (d1 a) (d1 b)  d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STLEither a b) = STLEither (d1 a) (d1 b)  d1 (STMaybe t) = STMaybe (d1 t)  d1 (STArr n t) = STArr n (d1 t)  d1 (STScal t) = STScal t  d1 STAccum{} = error "Accumulators not allowed in input program" -d1 (STLEither a b) = STLEither (d1 a) (d1 b)  d1e :: SList STy env -> SList STy (D1E env)  d1e SNil = SNil @@ -62,6 +62,7 @@ d2M :: STy t -> SMTy (D2 t)  d2M STNil = SMTNil  d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b))  d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)  d2M (STMaybe t) = SMTMaybe (d2M t)  d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t))  d2M (STScal t) = case t of @@ -71,7 +72,6 @@ d2M (STScal t) = case t of    STF64 -> SMTScal STF64    STBool -> SMTNil  d2M STAccum{} = error "Accumulators not allowed in input program" -d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)  d2 :: STy t -> STy (D2 t)  d2 = fromSMTy . d2M diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index 87c01cb..8476712 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -28,6 +28,11 @@ toTan typ primal der = case typ of                          (Left p, Left d') -> Left (toTan t1 p d')                          (Right p, Right d') -> Right (toTan t2 p d')                          _ -> error "Primal and cotangent disagree on Either alternative" +  STLEither t1 t2 -> case (primal, der) of +                       (_, Nothing) -> Nothing +                       (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) +                       (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) +                       _ -> error "Primal and cotangent disagree on LEither alternative"    STMaybe t -> liftA2 (toTan t) primal der    STArr _ t -> case der of                   Nothing -> arrayMap (zeroTan t) primal @@ -40,8 +45,3 @@ toTan typ primal der = case typ of    STScal sty -> case sty of      STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der    STAccum{} -> error "Accumulators not allowed in input program" -  STLEither t1 t2 -> case (primal, der) of -                       (_, Nothing) -> Nothing -                       (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) -                       (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) -                       _ -> error "Primal and cotangent disagree on LEither alternative" | 
