summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-29 20:37:06 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-29 20:38:05 +0200
commitd0eb9a1edfb4233d557d954f46685f25382234d8 (patch)
tree04eb5a746258fcaa2a3b98228c6eadb2b0178ba3 /src
parent4ad7eaba73d5fda8ff5028d1e53966f728d704d3 (diff)
Reorder TLEither to after TEither
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs14
-rw-r--r--src/AST/Accum.hs2
-rw-r--r--src/AST/Pretty.hs2
-rw-r--r--src/AST/SplitLets.hs4
-rw-r--r--src/AST/Types.hs14
-rw-r--r--src/Analysis/Identity.hs4
-rw-r--r--src/CHAD.hs2
-rw-r--r--src/CHAD/Top.hs2
-rw-r--r--src/CHAD/Types.hs8
-rw-r--r--src/CHAD/Types/ToTan.hs10
-rw-r--r--src/Compile.hs72
-rw-r--r--src/ForwardAD.hs48
-rw-r--r--src/ForwardAD/DualNumbers/Types.hs4
-rw-r--r--src/Interpreter/Rep.hs14
-rw-r--r--src/Simplify.hs2
15 files changed, 101 insertions, 101 deletions
diff --git a/src/AST.hs b/src/AST.hs
index b2f5ce7..ca66e87 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -55,10 +55,6 @@ data Expr x env t where
ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b
- ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
- ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
- ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
- ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
-- array operations
EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
@@ -99,6 +95,12 @@ data Expr x env t where
EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t
+ -- interface of abstract monoidal types
+ ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
+ ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
+ ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
+ ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
+
-- partiality
EError :: x a -> STy a -> String -> Expr x env a
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
@@ -376,11 +378,11 @@ class KnownTy t where knownTy :: STy t
instance KnownTy TNil where knownTy = STNil
instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
-instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
class KnownMTy t where knownMTy :: SMTy t
instance KnownMTy TNil where knownMTy = SMTNil
@@ -398,11 +400,11 @@ styKnown :: STy t -> Dict (KnownTy t)
styKnown STNil = Dict
styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STMaybe t) | Dict <- styKnown t = Dict
styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
styKnown (STAccum t) | Dict <- smtyKnown t = Dict
-styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
smtyKnown :: SMTy t -> Dict (KnownMTy t)
smtyKnown SMTNil = Dict
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index e84034b..03369c8 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -79,6 +79,7 @@ lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil
lemZeroInfoD2 STNil = Refl
lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
+lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl
lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl
lemZeroInfoD2 (STScal STI32) = Refl
@@ -87,7 +88,6 @@ lemZeroInfoD2 (STScal STF32) = Refl
lemZeroInfoD2 (STScal STF64) = Refl
lemZeroInfoD2 (STScal STBool) = Refl
lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program"
-lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index e09f3ae..2bb78d4 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -388,6 +388,7 @@ ppSTy' :: Int -> STy t -> Doc q
ppSTy' _ STNil = ppString "1"
ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b
ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b
+ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t
ppSTy' d (STArr n t) = ppParen (d > 10) $
ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t
@@ -398,7 +399,6 @@ ppSTy' _ (STScal sty) = ppString $ case sty of
STF64 -> "f64"
STBool -> "bool"
ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t
-ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
ppSMTy :: Int -> SMTy t -> String
ppSMTy d ty = render $ ppSMTy' d ty
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 159934d..1379e35 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -123,11 +123,11 @@ split typ = case typ of
STPair{} -> splitRec (EVar ext typ IZ) typ
STNil -> other
STEither{} -> other
+ STLEither{} -> other
STMaybe{} -> other
STArr{} -> other
STScal{} -> other
STAccum{} -> other
- STLEither{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])
other = (Point typ IZ, BTop)
@@ -142,11 +142,11 @@ splitRec rhs typ = case typ of
(p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b
in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2)
STEither{} -> other
+ STLEither{} -> other
STMaybe{} -> other
STArr{} -> other
STScal{} -> other
STAccum{} -> other
- STLEither{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex env '[t])
other = (Point typ IZ, BPush BTop (typ, rhs))
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index efb1e04..a3b7302 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -23,12 +23,11 @@ type data Ty
= TNil
| TPair Ty Ty
| TEither Ty Ty
+ | TLEither Ty Ty
| TMaybe Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
| TAccum Ty -- ^ contained type must be a monoid type
- -- sparse monoid types
- | TLEither Ty Ty
type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
@@ -37,12 +36,11 @@ data STy t where
STNil :: STy TNil
STPair :: STy a -> STy b -> STy (TPair a b)
STEither :: STy a -> STy b -> STy (TEither a b)
+ STLEither :: STy a -> STy b -> STy (TLEither a b)
STMaybe :: STy a -> STy (TMaybe a)
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
STAccum :: SMTy t -> STy (TAccum t)
- -- sparse monoid types
- STLEither :: STy a -> STy b -> STy (TLEither a b)
deriving instance Show (STy t)
instance GCompare STy where
@@ -53,6 +51,8 @@ instance GCompare STy where
STPair{} _ -> GLT ; _ STPair{} -> GGT
(STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
STEither{} _ -> GLT ; _ STEither{} -> GGT
+ (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STLEither{} _ -> GLT ; _ STLEither{} -> GGT
(STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a')
STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT
(STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
@@ -60,9 +60,7 @@ instance GCompare STy where
(STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
STScal{} _ -> GLT ; _ STScal{} -> GGT
(STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
- (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
- -- STLEither{} _ -> GLT ; _ STLEither{} -> GGT
+ -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
instance TestEquality STy where testEquality = geq
instance GEq STy where geq = defaultGeq
@@ -177,11 +175,11 @@ hasArrays :: STy t' -> Bool
hasArrays STNil = False
hasArrays (STPair a b) = hasArrays a || hasArrays b
hasArrays (STEither a b) = hasArrays a || hasArrays b
+hasArrays (STLEither a b) = hasArrays a || hasArrays b
hasArrays (STMaybe t) = hasArrays t
hasArrays STArr{} = True
hasArrays STScal{} = False
hasArrays STAccum{} = True
-hasArrays (STLEither a b) = hasArrays a || hasArrays b
type family Tup env where
Tup '[] = TNil
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 20575b3..a1a6376 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -28,9 +28,9 @@ data ValId t where
VIPair :: ValId a -> ValId b -> ValId (TPair a b)
VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative
VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case
+ VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)
VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)
VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value
- VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)
VIArr :: Int -> Vec n Int -> ValId (TArr n t)
VIScal :: Int -> ValId (TScal t)
VIAccum :: Int -> ValId (TAccum t)
@@ -367,11 +367,11 @@ genIds :: STy t -> IdGen (ValId t)
genIds STNil = pure VINil
genIds (STPair a b) = VIPair <$> genIds a <*> genIds b
genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b
+genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)
genIds (STMaybe t) = VIMaybe' <$> genIds t
genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId
genIds STScal{} = VIScal <$> genId
genIds STAccum{} = VIAccum <$> genId
-genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)
shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int)
shidsToVec SZ _ = pure VNil
diff --git a/src/CHAD.hs b/src/CHAD.hs
index ac308ac..3a7b907 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -484,6 +484,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
STNil -> True
STPair a b -> isDiscrete a && isDiscrete b
STEither a b -> isDiscrete a && isDiscrete b
+ STLEither a b -> isDiscrete a && isDiscrete b
STMaybe a -> isDiscrete a
STArr _ a -> isDiscrete a
STScal st -> case st of
@@ -493,7 +494,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
STF64 -> False
STBool -> True
STAccum{} -> False
- STLEither a b -> isDiscrete a && isDiscrete b
---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 9e7e7f5..261ddfe 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -49,11 +49,11 @@ d1Identity = \case
STNil -> Refl
STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
STMaybe t | Refl <- d1Identity t -> Refl
STArr _ t | Refl <- d1Identity t -> Refl
STScal _ -> Refl
STAccum{} -> error "Accumulators not allowed in input program"
- STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 74e7dbd..974669d 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -11,19 +11,19 @@ type family D1 t where
D1 TNil = TNil
D1 (TPair a b) = TPair (D1 a) (D1 b)
D1 (TEither a b) = TEither (D1 a) (D1 b)
+ D1 (TLEither a b) = TLEither (D1 a) (D1 b)
D1 (TMaybe a) = TMaybe (D1 a)
D1 (TArr n t) = TArr n (D1 t)
D1 (TScal t) = TScal t
- D1 (TLEither a b) = TLEither (D1 a) (D1 b)
type family D2 t where
D2 TNil = TNil
D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
D2 (TEither a b) = TLEither (D2 a) (D2 b)
+ D2 (TLEither a b) = TLEither (D2 a) (D2 b)
D2 (TMaybe t) = TMaybe (D2 t)
D2 (TArr n t) = TMaybe (TArr n (D2 t))
D2 (TScal t) = D2s t
- D2 (TLEither a b) = TLEither (D2 a) (D2 b)
type family D2s t where
D2s TI32 = TNil
@@ -48,11 +48,11 @@ d1 :: STy t -> STy (D1 t)
d1 STNil = STNil
d1 (STPair a b) = STPair (d1 a) (d1 b)
d1 (STEither a b) = STEither (d1 a) (d1 b)
+d1 (STLEither a b) = STLEither (d1 a) (d1 b)
d1 (STMaybe t) = STMaybe (d1 t)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STAccum{} = error "Accumulators not allowed in input program"
-d1 (STLEither a b) = STLEither (d1 a) (d1 b)
d1e :: SList STy env -> SList STy (D1E env)
d1e SNil = SNil
@@ -62,6 +62,7 @@ d2M :: STy t -> SMTy (D2 t)
d2M STNil = SMTNil
d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b))
d2M (STEither a b) = SMTLEither (d2M a) (d2M b)
+d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
d2M (STMaybe t) = SMTMaybe (d2M t)
d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t))
d2M (STScal t) = case t of
@@ -71,7 +72,6 @@ d2M (STScal t) = case t of
STF64 -> SMTScal STF64
STBool -> SMTNil
d2M STAccum{} = error "Accumulators not allowed in input program"
-d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
d2 :: STy t -> STy (D2 t)
d2 = fromSMTy . d2M
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
index 87c01cb..8476712 100644
--- a/src/CHAD/Types/ToTan.hs
+++ b/src/CHAD/Types/ToTan.hs
@@ -28,6 +28,11 @@ toTan typ primal der = case typ of
(Left p, Left d') -> Left (toTan t1 p d')
(Right p, Right d') -> Right (toTan t2 p d')
_ -> error "Primal and cotangent disagree on Either alternative"
+ STLEither t1 t2 -> case (primal, der) of
+ (_, Nothing) -> Nothing
+ (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d))
+ (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d))
+ _ -> error "Primal and cotangent disagree on LEither alternative"
STMaybe t -> liftA2 (toTan t) primal der
STArr _ t -> case der of
Nothing -> arrayMap (zeroTan t) primal
@@ -40,8 +45,3 @@ toTan typ primal der = case typ of
STScal sty -> case sty of
STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
STAccum{} -> error "Accumulators not allowed in input program"
- STLEither t1 t2 -> case (primal, der) of
- (_, Nothing) -> Nothing
- (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d))
- (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d))
- _ -> error "Primal and cotangent disagree on LEither alternative"
diff --git a/src/Compile.hs b/src/Compile.hs
index 6ba3a39..cd10831 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -221,6 +221,7 @@ genStructName = \t -> "ty_" ++ gen t where
gen STNil = "n"
gen (STPair a b) = 'P' : gen a ++ gen b
gen (STEither a b) = 'E' : gen a ++ gen b
+ gen (STLEither a b) = 'L' : gen a ++ gen b
gen (STMaybe t) = 'M' : gen t
gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
gen (STScal st) = case st of
@@ -230,7 +231,6 @@ genStructName = \t -> "ty_" ++ gen t where
STF64 -> "d"
STBool -> "b"
gen (STAccum t) = 'C' : gen (fromSMTy t)
- gen (STLEither a b) = 'L' : gen a ++ gen b
-- | This function generates the actual struct declarations for each of the
-- types in our language. It thus implicitly "documents" the layout of the
@@ -246,6 +246,8 @@ genStruct name topty = case topty of
[StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
STEither a b -> -- 0 -> l, 1 -> r
[StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
+ [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
STMaybe t -> -- 0 -> nothing, 1 -> just
[StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
STArr n t ->
@@ -259,8 +261,6 @@ genStruct name topty = case topty of
STAccum t ->
[StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""
,StructDecl name (name ++ "_buf *buf;") com]
- STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r
- [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
where
com = ppSTy 0 topty
@@ -282,11 +282,11 @@ genStructs ty = do
STNil -> pure ()
STPair a b -> genStructs a >> genStructs b
STEither a b -> genStructs a >> genStructs b
+ STLEither a b -> genStructs a >> genStructs b
STMaybe t -> genStructs t
STArr _ t -> genStructs t
STScal _ -> pure ()
STAccum t -> genStructs (fromSMTy t)
- STLEither a b -> genStructs a >> genStructs b
tell (BList (genStruct name ty))
@@ -463,6 +463,15 @@ serialise topty topval ptr off k =
(STEither _ b, Right y) -> do
pokeByteOff ptr off (1 :: Word8)
serialise b y ptr (off + alignmentSTy topty) k
+ (STLEither _ _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STLEither a _, Just (Left x)) -> do
+ pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STLEither _ b, Just (Right y)) -> do
+ pokeByteOff ptr off (2 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
(STMaybe _, Nothing) -> do
pokeByteOff ptr off (0 :: Word8)
k
@@ -493,15 +502,6 @@ serialise topty topval ptr off k =
STF64 -> pokeByteOff ptr off (x :: Double) >> k
STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k
(STAccum{}, _) -> error "Cannot serialise accumulators"
- (STLEither _ _, Nothing) -> do
- pokeByteOff ptr off (0 :: Word8)
- k
- (STLEither a _, Just (Left x)) -> do
- pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b)
- serialise a x ptr (off + alignmentSTy topty) k
- (STLEither _ b, Just (Right y)) -> do
- pokeByteOff ptr off (2 :: Word8)
- serialise b y ptr (off + alignmentSTy topty) k
-- | Assumes that this is called at the correct alignment.
deserialise :: STy t -> Ptr () -> Int -> IO (Rep t)
@@ -518,6 +518,13 @@ deserialise topty ptr off =
if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b)
then Left <$> deserialise a ptr (off + alignmentSTy topty)
else Right <$> deserialise b ptr (off + alignmentSTy topty)
+ STLEither a b -> do
+ tag <- peekByteOff @Word8 ptr off
+ case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
+ 0 -> return Nothing
+ 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
+ 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
+ _ -> error "Invalid tag value"
STMaybe t -> do
tag <- peekByteOff @Word8 ptr off
if tag == 0
@@ -541,13 +548,6 @@ deserialise topty ptr off =
STF64 -> peekByteOff @Double ptr off
STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off
STAccum{} -> error "Cannot serialise accumulators"
- STLEither a b -> do
- tag <- peekByteOff @Word8 ptr off
- case tag of -- alignment of (union {a b}) is the same as alignment of (a + b)
- 0 -> return Nothing
- 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty)
- 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty)
- _ -> error "Invalid tag value"
align :: Int -> Int -> Int
align a off = (off + a - 1) `div` a * a
@@ -569,6 +569,10 @@ metricsSTy (STEither a b) =
let (a1, s1) = metricsSTy a
(a2, s2) = metricsSTy b
in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
+metricsSTy (STLEither a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
metricsSTy (STMaybe t) =
let (a, s) = metricsSTy t
in (a, a + s) -- the union after the tag byte is aligned
@@ -580,10 +584,6 @@ metricsSTy (STScal sty) = case sty of
STF64 -> (8, 8)
STBool -> (1, 1) -- compiled to uint8_t
metricsSTy (STAccum t) = metricsSTy (fromSMTy t)
-metricsSTy (STLEither a b) =
- let (a1, s1) = metricsSTy a
- (a2, s2) = metricsSTy b
- in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
pokeShape ptr off = go . fromSNat
@@ -1071,8 +1071,8 @@ compile' env = \case
incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend
emit $ SAsg v (CELit addend)
-- sparse types
- (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
(SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
+ (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
-- dense types
(SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
f (v++".a") (i++".a")
@@ -1303,13 +1303,13 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))
(smartATProj "b" (makeArrayTree b))
makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))
(smartATProj "r" (makeArrayTree b))
+makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
+ (smartATProj "l" (makeArrayTree a))
+ (smartATProj "r" (makeArrayTree b))
makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))
makeArrayTree (STArr n t) = ATArray (Some n) (Some t)
makeArrayTree (STScal _) = ATNoop
makeArrayTree (STAccum _) = ATNoop
-makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop
- (smartATProj "l" (makeArrayTree a))
- (smartATProj "r" (makeArrayTree b))
incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM ()
incrementVar' marker inc path (ATArray (Some n) (Some eltty)) =
@@ -1657,6 +1657,14 @@ zeroRefcountCheck toptyp opname topvar =
go (STEither a b) path = do
(s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
+ go (STLEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $
+ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
+ s1
+ (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
+ s2
+ mempty))
go (STMaybe a) path = do
ss <- go a (path++".j")
return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
@@ -1673,14 +1681,6 @@ zeroRefcountCheck toptyp opname topvar =
return (BList [s1, s2, s3])
go STScal{} _ = empty
go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"
- go (STLEither a b) path = do
- (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
- return $ pure $
- SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1"))
- s1
- (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2"))
- s2
- mempty))
combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
combine (MaybeT a) (MaybeT b) = MaybeT $ do
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index 5756f96..b353def 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -26,10 +26,10 @@ type family Tan t where
Tan TNil = TNil
Tan (TPair a b) = TPair (Tan a) (Tan b)
Tan (TEither a b) = TEither (Tan a) (Tan b)
+ Tan (TLEither a b) = TLEither (Tan a) (Tan b)
Tan (TMaybe t) = TMaybe (Tan t)
Tan (TArr n t) = TArr n (Tan t)
Tan (TScal t) = TanS t
- Tan (TLEither a b) = TLEither (Tan a) (Tan b)
type family TanS t where
TanS TI32 = TNil
@@ -46,6 +46,7 @@ tanty :: STy t -> STy (Tan t)
tanty STNil = STNil
tanty (STPair a b) = STPair (tanty a) (tanty b)
tanty (STEither a b) = STEither (tanty a) (tanty b)
+tanty (STLEither a b) = STLEither (tanty a) (tanty b)
tanty (STMaybe t) = STMaybe (tanty t)
tanty (STArr n t) = STArr n (tanty t)
tanty (STScal t) = case t of
@@ -55,7 +56,6 @@ tanty (STScal t) = case t of
STF64 -> STScal STF64
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
-tanty (STLEither a b) = STLEither (tanty a) (tanty b)
tanenv :: SList STy env -> SList STy (TanE env)
tanenv SNil = SNil
@@ -66,6 +66,9 @@ zeroTan STNil () = ()
zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)
zeroTan (STEither a _) (Left x) = Left (zeroTan a x)
zeroTan (STEither _ b) (Right y) = Right (zeroTan b y)
+zeroTan (STLEither _ _) Nothing = Nothing
+zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
+zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
zeroTan (STMaybe _) Nothing = Nothing
zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)
zeroTan (STArr _ t) x = fmap (zeroTan t) x
@@ -75,15 +78,15 @@ zeroTan (STScal STF32) _ = 0.0
zeroTan (STScal STF64) _ = 0.0
zeroTan (STScal STBool) _ = ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
-zeroTan (STLEither _ _) Nothing = Nothing
-zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
-zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y
tanScalars (STEither a _) (Left x) = tanScalars a x
tanScalars (STEither _ b) (Right y) = tanScalars b y
+tanScalars (STLEither _ _) Nothing = []
+tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
+tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
tanScalars (STMaybe _) Nothing = []
tanScalars (STMaybe t) (Just x) = tanScalars t x
tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x
@@ -93,9 +96,6 @@ tanScalars (STScal STF32) x = [realToFrac x]
tanScalars (STScal STF64) x = [x]
tanScalars (STScal STBool) _ = []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
-tanScalars (STLEither _ _) Nothing = []
-tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
-tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
tanEScalars SNil SNil = []
@@ -110,6 +110,10 @@ unzipDN (STPair a b) (d1, d2) =
unzipDN (STEither a b) d = case d of
Left d1 -> bimap Left Left (unzipDN a d1)
Right d2 -> bimap Right Right (unzipDN b d2)
+unzipDN (STLEither a b) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
+ Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
unzipDN (STMaybe t) d = case d of
Nothing -> (Nothing, Nothing)
Just d' -> bimap Just Just (unzipDN t d')
@@ -123,10 +127,6 @@ unzipDN (STScal ty) d = case ty of
STF64 -> d
STBool -> (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
-unzipDN (STLEither a b) d = case d of
- Nothing -> (Nothing, Nothing)
- Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
- Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
@@ -136,6 +136,12 @@ dotprodTan (STEither a b) x y = case (x, y) of
(Left x', Left y') -> dotprodTan a x' y'
(Right x', Right y') -> dotprodTan b x' y'
_ -> error "dotprodTan: incompatible Either alternatives"
+dotprodTan (STLEither a b) x y = case (x, y) of
+ (Nothing, _) -> 0.0 -- 0 * y = 0
+ (_, Nothing) -> 0.0 -- x * 0 = 0
+ (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
+ (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible LEither alternatives"
dotprodTan (STMaybe t) x y = case (x, y) of
(Nothing, Nothing) -> 0.0
(Just x', Just y') -> dotprodTan t x' y'
@@ -153,12 +159,6 @@ dotprodTan (STScal ty) x y = case ty of
STF64 -> x * y
STBool -> 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
-dotprodTan (STLEither a b) x y = case (x, y) of
- (Nothing, _) -> 0.0 -- 0 * y = 0
- (_, Nothing) -> 0.0 -- x * 0 = 0
- (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
- (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
- _ -> error "dotprodTan: incompatible LEither alternatives"
-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
@@ -187,6 +187,7 @@ dnConst :: STy t -> Rep t -> Rep (DN t)
dnConst STNil = const ()
dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)
dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
dnConst (STMaybe t) = fmap (dnConst t)
dnConst (STArr _ t) = arrayMap (dnConst t)
dnConst (STScal t) = case t of
@@ -196,7 +197,6 @@ dnConst (STScal t) = case t of
STF64 -> (,0.0)
STBool -> id
dnConst STAccum{} = error "Accumulators not allowed in input program"
-dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
-- | Given a function that computes the forward derivative for a particular
-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
@@ -211,6 +211,11 @@ dnOnehots (STEither t1 t2) e =
case e of
Left x -> \f -> Left (dnOnehots t1 x (f . Left))
Right y -> \f -> Right (dnOnehots t2 y (f . Right))
+dnOnehots (STLEither t1 t2) e =
+ case e of
+ Nothing -> \_ -> Nothing
+ Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
+ Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
dnOnehots (STMaybe t) m =
case m of
Nothing -> \_ -> Nothing
@@ -227,11 +232,6 @@ dnOnehots (STScal t) x = case t of
STF64 -> \f -> f (x, 1.0)
STBool -> \_ -> ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
-dnOnehots (STLEither t1 t2) e =
- case e of
- Nothing -> \_ -> Nothing
- Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
- Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil
diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs
index 3c76cbe..dcacf5f 100644
--- a/src/ForwardAD/DualNumbers/Types.hs
+++ b/src/ForwardAD/DualNumbers/Types.hs
@@ -12,10 +12,10 @@ type family DN t where
DN TNil = TNil
DN (TPair a b) = TPair (DN a) (DN b)
DN (TEither a b) = TEither (DN a) (DN b)
+ DN (TLEither a b) = TLEither (DN a) (DN b)
DN (TMaybe t) = TMaybe (DN t)
DN (TArr n t) = TArr n (DN t)
DN (TScal t) = DNS t
- DN (TLEither a b) = TLEither (DN a) (DN b)
type family DNS t where
DNS TF32 = TPair (TScal TF32) (TScal TF32)
@@ -32,6 +32,7 @@ dn :: STy t -> STy (DN t)
dn STNil = STNil
dn (STPair a b) = STPair (dn a) (dn b)
dn (STEither a b) = STEither (dn a) (dn b)
+dn (STLEither a b) = STLEither (dn a) (dn b)
dn (STMaybe t) = STMaybe (dn t)
dn (STArr n t) = STArr n (dn t)
dn (STScal t) = case t of
@@ -41,7 +42,6 @@ dn (STScal t) = case t of
STI64 -> STScal STI64
STBool -> STScal STBool
dn STAccum{} = error "Accum in source program"
-dn (STLEither a b) = STLEither (dn a) (dn b)
dne :: SList STy env -> SList STy (DNE env)
dne SNil = SNil
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 070ba4c..1682303 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -22,11 +22,11 @@ type family Rep t where
Rep TNil = ()
Rep (TPair a b) = (Rep a, Rep b)
Rep (TEither a b) = Either (Rep a) (Rep b)
+ Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b))
Rep (TMaybe t) = Maybe (Rep t)
Rep (TArr n t) = Array n (Rep t)
Rep (TScal sty) = ScalRep sty
Rep (TAccum t) = RepAc t
- Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b))
-- Mutable, represents monoid types t.
type family RepAc t where
@@ -56,6 +56,9 @@ showValue _ STNil () = showString "()"
showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x
showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y
+showValue _ (STLEither _ _) Nothing = showString "LNil"
+showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x
+showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y
showValue _ (STMaybe _) Nothing = showString "Nothing"
showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
showValue d (STArr _ t) arr = showParen (d > 10) $
@@ -70,9 +73,6 @@ showValue d (STScal sty) x = case sty of
STI64 -> showsPrec d x
STBool -> showsPrec d x
showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">"
-showValue _ (STLEither _ _) Nothing = showString "LNil"
-showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x
-showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y
showEnv :: SList STy env -> SList Value env -> String
showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
@@ -86,6 +86,9 @@ rnfRep STNil () = ()
rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y
rnfRep (STEither a _) (Left x) = rnfRep a x
rnfRep (STEither _ b) (Right y) = rnfRep b y
+rnfRep (STLEither _ _) Nothing = ()
+rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x
+rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y
rnfRep (STMaybe _) Nothing = ()
rnfRep (STMaybe t) (Just x) = rnfRep t x
rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr =
@@ -97,9 +100,6 @@ rnfRep (STScal t) x = case t of
STF64 -> rnf x
STBool -> rnf x
rnfRep STAccum{} _ = error "Cannot rnf accumulators"
-rnfRep (STLEither _ _) Nothing = ()
-rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x
-rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y
instance KnownTy t => NFData (Value t) where
rnf (Value x) = rnfRep (knownTy @t) x
diff --git a/src/Simplify.hs b/src/Simplify.hs
index f5eb0a1..6f97e6d 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -359,11 +359,11 @@ checkAccumInScope = \case SNil -> False
check STNil = False
check (STPair s t) = check s || check t
check (STEither s t) = check s || check t
+ check (STLEither s t) = check s || check t
check (STMaybe t) = check t
check (STArr _ t) = check t
check (STScal _) = False
check STAccum{} = True
- check (STLEither s t) = check s || check t
data OneHotTerm env p a b where
OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b