diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 20:37:06 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 20:38:05 +0200 |
commit | d0eb9a1edfb4233d557d954f46685f25382234d8 (patch) | |
tree | 04eb5a746258fcaa2a3b98228c6eadb2b0178ba3 /src | |
parent | 4ad7eaba73d5fda8ff5028d1e53966f728d704d3 (diff) |
Reorder TLEither to after TEither
Diffstat (limited to 'src')
-rw-r--r-- | src/AST.hs | 14 | ||||
-rw-r--r-- | src/AST/Accum.hs | 2 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 2 | ||||
-rw-r--r-- | src/AST/SplitLets.hs | 4 | ||||
-rw-r--r-- | src/AST/Types.hs | 14 | ||||
-rw-r--r-- | src/Analysis/Identity.hs | 4 | ||||
-rw-r--r-- | src/CHAD.hs | 2 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 2 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 8 | ||||
-rw-r--r-- | src/CHAD/Types/ToTan.hs | 10 | ||||
-rw-r--r-- | src/Compile.hs | 72 | ||||
-rw-r--r-- | src/ForwardAD.hs | 48 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers/Types.hs | 4 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 14 | ||||
-rw-r--r-- | src/Simplify.hs | 2 |
15 files changed, 101 insertions, 101 deletions
@@ -55,10 +55,6 @@ data Expr x env t where ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b - ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) - ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) - ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) - ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) @@ -99,6 +95,12 @@ data Expr x env t where EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t + -- interface of abstract monoidal types + ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) + ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) + ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) + ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c + -- partiality EError :: x a -> STy a -> String -> Expr x env a deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) @@ -376,11 +378,11 @@ class KnownTy t where knownTy :: STy t instance KnownTy TNil where knownTy = STNil instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy -instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy class KnownMTy t where knownMTy :: SMTy t instance KnownMTy TNil where knownMTy = SMTNil @@ -398,11 +400,11 @@ styKnown :: STy t -> Dict (KnownTy t) styKnown STNil = Dict styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict styKnown (STMaybe t) | Dict <- styKnown t = Dict styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict styKnown (STScal t) | Dict <- sscaltyKnown t = Dict styKnown (STAccum t) | Dict <- smtyKnown t = Dict -styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict smtyKnown :: SMTy t -> Dict (KnownMTy t) smtyKnown SMTNil = Dict diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index e84034b..03369c8 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -79,6 +79,7 @@ lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil lemZeroInfoD2 STNil = Refl lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl +lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl lemZeroInfoD2 (STScal STI32) = Refl @@ -87,7 +88,6 @@ lemZeroInfoD2 (STScal STF32) = Refl lemZeroInfoD2 (STScal STF64) = Refl lemZeroInfoD2 (STScal STBool) = Refl lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" -lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index e09f3ae..2bb78d4 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -388,6 +388,7 @@ ppSTy' :: Int -> STy t -> Doc q ppSTy' _ STNil = ppString "1" ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b +ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t ppSTy' d (STArr n t) = ppParen (d > 10) $ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t @@ -398,7 +399,6 @@ ppSTy' _ (STScal sty) = ppString $ case sty of STF64 -> "f64" STBool -> "bool" ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t -ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b ppSMTy :: Int -> SMTy t -> String ppSMTy d ty = render $ ppSMTy' d ty diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 159934d..1379e35 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -123,11 +123,11 @@ split typ = case typ of STPair{} -> splitRec (EVar ext typ IZ) typ STNil -> other STEither{} -> other + STLEither{} -> other STMaybe{} -> other STArr{} -> other STScal{} -> other STAccum{} -> other - STLEither{} -> other where other :: (Pointers (t : env) t, Bindings Ex (t : env) '[]) other = (Point typ IZ, BTop) @@ -142,11 +142,11 @@ splitRec rhs typ = case typ of (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) STEither{} -> other + STLEither{} -> other STMaybe{} -> other STArr{} -> other STScal{} -> other STAccum{} -> other - STLEither{} -> other where other :: (Pointers (t : env) t, Bindings Ex env '[t]) other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/AST/Types.hs b/src/AST/Types.hs index efb1e04..a3b7302 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -23,12 +23,11 @@ type data Ty = TNil | TPair Ty Ty | TEither Ty Ty + | TLEither Ty Ty | TMaybe Ty | TArr Nat Ty -- ^ rank, element type | TScal ScalTy | TAccum Ty -- ^ contained type must be a monoid type - -- sparse monoid types - | TLEither Ty Ty type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool @@ -37,12 +36,11 @@ data STy t where STNil :: STy TNil STPair :: STy a -> STy b -> STy (TPair a b) STEither :: STy a -> STy b -> STy (TEither a b) + STLEither :: STy a -> STy b -> STy (TLEither a b) STMaybe :: STy a -> STy (TMaybe a) STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) STAccum :: SMTy t -> STy (TAccum t) - -- sparse monoid types - STLEither :: STy a -> STy b -> STy (TLEither a b) deriving instance Show (STy t) instance GCompare STy where @@ -53,6 +51,8 @@ instance GCompare STy where STPair{} _ -> GLT ; _ STPair{} -> GGT (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') STEither{} _ -> GLT ; _ STEither{} -> GGT + (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STLEither{} _ -> GLT ; _ STLEither{} -> GGT (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a') STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') @@ -60,9 +60,7 @@ instance GCompare STy where (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') STScal{} _ -> GLT ; _ STScal{} -> GGT (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') - STAccum{} _ -> GLT ; _ STAccum{} -> GGT - (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - -- STLEither{} _ -> GLT ; _ STLEither{} -> GGT + -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT instance TestEquality STy where testEquality = geq instance GEq STy where geq = defaultGeq @@ -177,11 +175,11 @@ hasArrays :: STy t' -> Bool hasArrays STNil = False hasArrays (STPair a b) = hasArrays a || hasArrays b hasArrays (STEither a b) = hasArrays a || hasArrays b +hasArrays (STLEither a b) = hasArrays a || hasArrays b hasArrays (STMaybe t) = hasArrays t hasArrays STArr{} = True hasArrays STScal{} = False hasArrays STAccum{} = True -hasArrays (STLEither a b) = hasArrays a || hasArrays b type family Tup env where Tup '[] = TNil diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 20575b3..a1a6376 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -28,9 +28,9 @@ data ValId t where VIPair :: ValId a -> ValId b -> ValId (TPair a b) VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case + VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b) VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a) VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value - VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b) VIArr :: Int -> Vec n Int -> ValId (TArr n t) VIScal :: Int -> ValId (TScal t) VIAccum :: Int -> ValId (TAccum t) @@ -367,11 +367,11 @@ genIds :: STy t -> IdGen (ValId t) genIds STNil = pure VINil genIds (STPair a b) = VIPair <$> genIds a <*> genIds b genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b +genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b) genIds (STMaybe t) = VIMaybe' <$> genIds t genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId genIds STScal{} = VIScal <$> genId genIds STAccum{} = VIAccum <$> genId -genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b) shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int) shidsToVec SZ _ = pure VNil diff --git a/src/CHAD.hs b/src/CHAD.hs index ac308ac..3a7b907 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -484,6 +484,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of STNil -> True STPair a b -> isDiscrete a && isDiscrete b STEither a b -> isDiscrete a && isDiscrete b + STLEither a b -> isDiscrete a && isDiscrete b STMaybe a -> isDiscrete a STArr _ a -> isDiscrete a STScal st -> case st of @@ -493,7 +494,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of STF64 -> False STBool -> True STAccum{} -> False - STLEither a b -> isDiscrete a && isDiscrete b ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 9e7e7f5..261ddfe 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -49,11 +49,11 @@ d1Identity = \case STNil -> Refl STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl STMaybe t | Refl <- d1Identity t -> Refl STArr _ t | Refl <- d1Identity t -> Refl STScal _ -> Refl STAccum{} -> error "Accumulators not allowed in input program" - STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl d1eIdentity :: SList STy env -> D1E env :~: env d1eIdentity SNil = Refl diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 74e7dbd..974669d 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -11,19 +11,19 @@ type family D1 t where D1 TNil = TNil D1 (TPair a b) = TPair (D1 a) (D1 b) D1 (TEither a b) = TEither (D1 a) (D1 b) + D1 (TLEither a b) = TLEither (D1 a) (D1 b) D1 (TMaybe a) = TMaybe (D1 a) D1 (TArr n t) = TArr n (D1 t) D1 (TScal t) = TScal t - D1 (TLEither a b) = TLEither (D1 a) (D1 b) type family D2 t where D2 TNil = TNil D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) D2 (TEither a b) = TLEither (D2 a) (D2 b) + D2 (TLEither a b) = TLEither (D2 a) (D2 b) D2 (TMaybe t) = TMaybe (D2 t) D2 (TArr n t) = TMaybe (TArr n (D2 t)) D2 (TScal t) = D2s t - D2 (TLEither a b) = TLEither (D2 a) (D2 b) type family D2s t where D2s TI32 = TNil @@ -48,11 +48,11 @@ d1 :: STy t -> STy (D1 t) d1 STNil = STNil d1 (STPair a b) = STPair (d1 a) (d1 b) d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STLEither a b) = STLEither (d1 a) (d1 b) d1 (STMaybe t) = STMaybe (d1 t) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STAccum{} = error "Accumulators not allowed in input program" -d1 (STLEither a b) = STLEither (d1 a) (d1 b) d1e :: SList STy env -> SList STy (D1E env) d1e SNil = SNil @@ -62,6 +62,7 @@ d2M :: STy t -> SMTy (D2 t) d2M STNil = SMTNil d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b)) d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) d2M (STMaybe t) = SMTMaybe (d2M t) d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t)) d2M (STScal t) = case t of @@ -71,7 +72,6 @@ d2M (STScal t) = case t of STF64 -> SMTScal STF64 STBool -> SMTNil d2M STAccum{} = error "Accumulators not allowed in input program" -d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) d2 :: STy t -> STy (D2 t) d2 = fromSMTy . d2M diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index 87c01cb..8476712 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -28,6 +28,11 @@ toTan typ primal der = case typ of (Left p, Left d') -> Left (toTan t1 p d') (Right p, Right d') -> Right (toTan t2 p d') _ -> error "Primal and cotangent disagree on Either alternative" + STLEither t1 t2 -> case (primal, der) of + (_, Nothing) -> Nothing + (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) + (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) + _ -> error "Primal and cotangent disagree on LEither alternative" STMaybe t -> liftA2 (toTan t) primal der STArr _ t -> case der of Nothing -> arrayMap (zeroTan t) primal @@ -40,8 +45,3 @@ toTan typ primal der = case typ of STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" - STLEither t1 t2 -> case (primal, der) of - (_, Nothing) -> Nothing - (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) - (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) - _ -> error "Primal and cotangent disagree on LEither alternative" diff --git a/src/Compile.hs b/src/Compile.hs index 6ba3a39..cd10831 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -221,6 +221,7 @@ genStructName = \t -> "ty_" ++ gen t where gen STNil = "n" gen (STPair a b) = 'P' : gen a ++ gen b gen (STEither a b) = 'E' : gen a ++ gen b + gen (STLEither a b) = 'L' : gen a ++ gen b gen (STMaybe t) = 'M' : gen t gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t gen (STScal st) = case st of @@ -230,7 +231,6 @@ genStructName = \t -> "ty_" ++ gen t where STF64 -> "d" STBool -> "b" gen (STAccum t) = 'C' : gen (fromSMTy t) - gen (STLEither a b) = 'L' : gen a ++ gen b -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the @@ -246,6 +246,8 @@ genStruct name topty = case topty of [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] STEither a b -> -- 0 -> l, 1 -> r [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r + [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] STMaybe t -> -- 0 -> nothing, 1 -> just [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] STArr n t -> @@ -259,8 +261,6 @@ genStruct name topty = case topty of STAccum t -> [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" ,StructDecl name (name ++ "_buf *buf;") com] - STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r - [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] where com = ppSTy 0 topty @@ -282,11 +282,11 @@ genStructs ty = do STNil -> pure () STPair a b -> genStructs a >> genStructs b STEither a b -> genStructs a >> genStructs b + STLEither a b -> genStructs a >> genStructs b STMaybe t -> genStructs t STArr _ t -> genStructs t STScal _ -> pure () STAccum t -> genStructs (fromSMTy t) - STLEither a b -> genStructs a >> genStructs b tell (BList (genStruct name ty)) @@ -463,6 +463,15 @@ serialise topty topval ptr off k = (STEither _ b, Right y) -> do pokeByteOff ptr off (1 :: Word8) serialise b y ptr (off + alignmentSTy topty) k + (STLEither _ _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STLEither a _, Just (Left x)) -> do + pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STLEither _ b, Just (Right y)) -> do + pokeByteOff ptr off (2 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k (STMaybe _, Nothing) -> do pokeByteOff ptr off (0 :: Word8) k @@ -493,15 +502,6 @@ serialise topty topval ptr off k = STF64 -> pokeByteOff ptr off (x :: Double) >> k STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k (STAccum{}, _) -> error "Cannot serialise accumulators" - (STLEither _ _, Nothing) -> do - pokeByteOff ptr off (0 :: Word8) - k - (STLEither a _, Just (Left x)) -> do - pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) - serialise a x ptr (off + alignmentSTy topty) k - (STLEither _ b, Just (Right y)) -> do - pokeByteOff ptr off (2 :: Word8) - serialise b y ptr (off + alignmentSTy topty) k -- | Assumes that this is called at the correct alignment. deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) @@ -518,6 +518,13 @@ deserialise topty ptr off = if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) then Left <$> deserialise a ptr (off + alignmentSTy topty) else Right <$> deserialise b ptr (off + alignmentSTy topty) + STLEither a b -> do + tag <- peekByteOff @Word8 ptr off + case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) + 0 -> return Nothing + 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) + 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) + _ -> error "Invalid tag value" STMaybe t -> do tag <- peekByteOff @Word8 ptr off if tag == 0 @@ -541,13 +548,6 @@ deserialise topty ptr off = STF64 -> peekByteOff @Double ptr off STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off STAccum{} -> error "Cannot serialise accumulators" - STLEither a b -> do - tag <- peekByteOff @Word8 ptr off - case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) - 0 -> return Nothing - 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) - 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) - _ -> error "Invalid tag value" align :: Int -> Int -> Int align a off = (off + a - 1) `div` a * a @@ -569,6 +569,10 @@ metricsSTy (STEither a b) = let (a1, s1) = metricsSTy a (a2, s2) = metricsSTy b in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned +metricsSTy (STLEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned metricsSTy (STMaybe t) = let (a, s) = metricsSTy t in (a, a + s) -- the union after the tag byte is aligned @@ -580,10 +584,6 @@ metricsSTy (STScal sty) = case sty of STF64 -> (8, 8) STBool -> (1, 1) -- compiled to uint8_t metricsSTy (STAccum t) = metricsSTy (fromSMTy t) -metricsSTy (STLEither a b) = - let (a1, s1) = metricsSTy a - (a2, s2) = metricsSTy b - in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () pokeShape ptr off = go . fromSNat @@ -1071,8 +1071,8 @@ compile' env = \case incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend emit $ SAsg v (CELit addend) -- sparse types - (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") + (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") -- dense types (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do f (v++".a") (i++".a") @@ -1303,13 +1303,13 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) (smartATProj "b" (makeArrayTree b)) makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a)) (smartATProj "r" (makeArrayTree b)) +makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop + (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t)) makeArrayTree (STArr n t) = ATArray (Some n) (Some t) makeArrayTree (STScal _) = ATNoop makeArrayTree (STAccum _) = ATNoop -makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop - (smartATProj "l" (makeArrayTree a)) - (smartATProj "r" (makeArrayTree b)) incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM () incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = @@ -1657,6 +1657,14 @@ zeroRefcountCheck toptyp opname topvar = go (STEither a b) path = do (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 + go (STLEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ + SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) + s1 + (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) + s2 + mempty)) go (STMaybe a) path = do ss <- go a (path++".j") return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty @@ -1673,14 +1681,6 @@ zeroRefcountCheck toptyp opname topvar = return (BList [s1, s2, s3]) go STScal{} _ = empty go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" - go (STLEither a b) path = do - (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) - return $ pure $ - SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) - s1 - (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) - s2 - mempty)) combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) combine (MaybeT a) (MaybeT b) = MaybeT $ do diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 5756f96..b353def 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -26,10 +26,10 @@ type family Tan t where Tan TNil = TNil Tan (TPair a b) = TPair (Tan a) (Tan b) Tan (TEither a b) = TEither (Tan a) (Tan b) + Tan (TLEither a b) = TLEither (Tan a) (Tan b) Tan (TMaybe t) = TMaybe (Tan t) Tan (TArr n t) = TArr n (Tan t) Tan (TScal t) = TanS t - Tan (TLEither a b) = TLEither (Tan a) (Tan b) type family TanS t where TanS TI32 = TNil @@ -46,6 +46,7 @@ tanty :: STy t -> STy (Tan t) tanty STNil = STNil tanty (STPair a b) = STPair (tanty a) (tanty b) tanty (STEither a b) = STEither (tanty a) (tanty b) +tanty (STLEither a b) = STLEither (tanty a) (tanty b) tanty (STMaybe t) = STMaybe (tanty t) tanty (STArr n t) = STArr n (tanty t) tanty (STScal t) = case t of @@ -55,7 +56,6 @@ tanty (STScal t) = case t of STF64 -> STScal STF64 STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" -tanty (STLEither a b) = STLEither (tanty a) (tanty b) tanenv :: SList STy env -> SList STy (TanE env) tanenv SNil = SNil @@ -66,6 +66,9 @@ zeroTan STNil () = () zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y) zeroTan (STEither a _) (Left x) = Left (zeroTan a x) zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) zeroTan (STMaybe _) Nothing = Nothing zeroTan (STMaybe t) (Just x) = Just (zeroTan t x) zeroTan (STArr _ t) x = fmap (zeroTan t) x @@ -75,15 +78,15 @@ zeroTan (STScal STF32) _ = 0.0 zeroTan (STScal STF64) _ = 0.0 zeroTan (STScal STBool) _ = () zeroTan STAccum{} _ = error "Accumulators not allowed in input program" -zeroTan (STLEither _ _) Nothing = Nothing -zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) -zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) tanScalars :: STy t -> Rep (Tan t) -> [Double] tanScalars STNil () = [] tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y tanScalars (STEither a _) (Left x) = tanScalars a x tanScalars (STEither _ b) (Right y) = tanScalars b y +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanScalars (STMaybe _) Nothing = [] tanScalars (STMaybe t) (Just x) = tanScalars t x tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x @@ -93,9 +96,6 @@ tanScalars (STScal STF32) x = [realToFrac x] tanScalars (STScal STF64) x = [x] tanScalars (STScal STBool) _ = [] tanScalars STAccum{} _ = error "Accumulators not allowed in input program" -tanScalars (STLEither _ _) Nothing = [] -tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x -tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] tanEScalars SNil SNil = [] @@ -110,6 +110,10 @@ unzipDN (STPair a b) (d1, d2) = unzipDN (STEither a b) d = case d of Left d1 -> bimap Left Left (unzipDN a d1) Right d2 -> bimap Right Right (unzipDN b d2) +unzipDN (STLEither a b) d = case d of + Nothing -> (Nothing, Nothing) + Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) + Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) unzipDN (STMaybe t) d = case d of Nothing -> (Nothing, Nothing) Just d' -> bimap Just Just (unzipDN t d') @@ -123,10 +127,6 @@ unzipDN (STScal ty) d = case ty of STF64 -> d STBool -> (d, ()) unzipDN STAccum{} _ = error "Accumulators not allowed in input program" -unzipDN (STLEither a b) d = case d of - Nothing -> (Nothing, Nothing) - Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) - Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double dotprodTan STNil _ _ = 0.0 @@ -136,6 +136,12 @@ dotprodTan (STEither a b) x y = case (x, y) of (Left x', Left y') -> dotprodTan a x' y' (Right x', Right y') -> dotprodTan b x' y' _ -> error "dotprodTan: incompatible Either alternatives" +dotprodTan (STLEither a b) x y = case (x, y) of + (Nothing, _) -> 0.0 -- 0 * y = 0 + (_, Nothing) -> 0.0 -- x * 0 = 0 + (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' + (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' + _ -> error "dotprodTan: incompatible LEither alternatives" dotprodTan (STMaybe t) x y = case (x, y) of (Nothing, Nothing) -> 0.0 (Just x', Just y') -> dotprodTan t x' y' @@ -153,12 +159,6 @@ dotprodTan (STScal ty) x y = case ty of STF64 -> x * y STBool -> 0.0 dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" -dotprodTan (STLEither a b) x y = case (x, y) of - (Nothing, _) -> 0.0 -- 0 * y = 0 - (_, Nothing) -> 0.0 -- x * 0 = 0 - (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' - (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' - _ -> error "dotprodTan: incompatible LEither alternatives" -- -- Primal expression must be duplicable -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -187,6 +187,7 @@ dnConst :: STy t -> Rep t -> Rep (DN t) dnConst STNil = const () dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2) dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) dnConst (STMaybe t) = fmap (dnConst t) dnConst (STArr _ t) = arrayMap (dnConst t) dnConst (STScal t) = case t of @@ -196,7 +197,6 @@ dnConst (STScal t) = case t of STF64 -> (,0.0) STBool -> id dnConst STAccum{} = error "Accumulators not allowed in input program" -dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) -- | Given a function that computes the forward derivative for a particular -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -211,6 +211,11 @@ dnOnehots (STEither t1 t2) e = case e of Left x -> \f -> Left (dnOnehots t1 x (f . Left)) Right y -> \f -> Right (dnOnehots t2 y (f . Right)) +dnOnehots (STLEither t1 t2) e = + case e of + Nothing -> \_ -> Nothing + Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) + Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) dnOnehots (STMaybe t) m = case m of Nothing -> \_ -> Nothing @@ -227,11 +232,6 @@ dnOnehots (STScal t) x = case t of STF64 -> \f -> f (x, 1.0) STBool -> \_ -> () dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" -dnOnehots (STLEither t1 t2) e = - case e of - Nothing -> \_ -> Nothing - Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) - Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) dnConstEnv SNil SNil = SNil diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs index 3c76cbe..dcacf5f 100644 --- a/src/ForwardAD/DualNumbers/Types.hs +++ b/src/ForwardAD/DualNumbers/Types.hs @@ -12,10 +12,10 @@ type family DN t where DN TNil = TNil DN (TPair a b) = TPair (DN a) (DN b) DN (TEither a b) = TEither (DN a) (DN b) + DN (TLEither a b) = TLEither (DN a) (DN b) DN (TMaybe t) = TMaybe (DN t) DN (TArr n t) = TArr n (DN t) DN (TScal t) = DNS t - DN (TLEither a b) = TLEither (DN a) (DN b) type family DNS t where DNS TF32 = TPair (TScal TF32) (TScal TF32) @@ -32,6 +32,7 @@ dn :: STy t -> STy (DN t) dn STNil = STNil dn (STPair a b) = STPair (dn a) (dn b) dn (STEither a b) = STEither (dn a) (dn b) +dn (STLEither a b) = STLEither (dn a) (dn b) dn (STMaybe t) = STMaybe (dn t) dn (STArr n t) = STArr n (dn t) dn (STScal t) = case t of @@ -41,7 +42,6 @@ dn (STScal t) = case t of STI64 -> STScal STI64 STBool -> STScal STBool dn STAccum{} = error "Accum in source program" -dn (STLEither a b) = STLEither (dn a) (dn b) dne :: SList STy env -> SList STy (DNE env) dne SNil = SNil diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 070ba4c..1682303 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -22,11 +22,11 @@ type family Rep t where Rep TNil = () Rep (TPair a b) = (Rep a, Rep b) Rep (TEither a b) = Either (Rep a) (Rep b) + Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b)) Rep (TMaybe t) = Maybe (Rep t) Rep (TArr n t) = Array n (Rep t) Rep (TScal sty) = ScalRep sty Rep (TAccum t) = RepAc t - Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b)) -- Mutable, represents monoid types t. type family RepAc t where @@ -56,6 +56,9 @@ showValue _ STNil () = showString "()" showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y +showValue _ (STLEither _ _) Nothing = showString "LNil" +showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x +showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y showValue _ (STMaybe _) Nothing = showString "Nothing" showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x showValue d (STArr _ t) arr = showParen (d > 10) $ @@ -70,9 +73,6 @@ showValue d (STScal sty) x = case sty of STI64 -> showsPrec d x STBool -> showsPrec d x showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">" -showValue _ (STLEither _ _) Nothing = showString "LNil" -showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x -showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y showEnv :: SList STy env -> SList Value env -> String showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" @@ -86,6 +86,9 @@ rnfRep STNil () = () rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y rnfRep (STEither a _) (Left x) = rnfRep a x rnfRep (STEither _ b) (Right y) = rnfRep b y +rnfRep (STLEither _ _) Nothing = () +rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x +rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y rnfRep (STMaybe _) Nothing = () rnfRep (STMaybe t) (Just x) = rnfRep t x rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr = @@ -97,9 +100,6 @@ rnfRep (STScal t) x = case t of STF64 -> rnf x STBool -> rnf x rnfRep STAccum{} _ = error "Cannot rnf accumulators" -rnfRep (STLEither _ _) Nothing = () -rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x -rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y instance KnownTy t => NFData (Value t) where rnf (Value x) = rnfRep (knownTy @t) x diff --git a/src/Simplify.hs b/src/Simplify.hs index f5eb0a1..6f97e6d 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -359,11 +359,11 @@ checkAccumInScope = \case SNil -> False check STNil = False check (STPair s t) = check s || check t check (STEither s t) = check s || check t + check (STLEither s t) = check s || check t check (STMaybe t) = check t check (STArr _ t) = check t check (STScal _) = False check STAccum{} = True - check (STLEither s t) = check s || check t data OneHotTerm env p a b where OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b |