diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 14 | ||||
| -rw-r--r-- | src/AST/Accum.hs | 2 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 2 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 4 | ||||
| -rw-r--r-- | src/AST/Types.hs | 14 | ||||
| -rw-r--r-- | src/Analysis/Identity.hs | 4 | ||||
| -rw-r--r-- | src/CHAD.hs | 2 | ||||
| -rw-r--r-- | src/CHAD/Top.hs | 2 | ||||
| -rw-r--r-- | src/CHAD/Types.hs | 8 | ||||
| -rw-r--r-- | src/CHAD/Types/ToTan.hs | 10 | ||||
| -rw-r--r-- | src/Compile.hs | 72 | ||||
| -rw-r--r-- | src/ForwardAD.hs | 48 | ||||
| -rw-r--r-- | src/ForwardAD/DualNumbers/Types.hs | 4 | ||||
| -rw-r--r-- | src/Interpreter/Rep.hs | 14 | ||||
| -rw-r--r-- | src/Simplify.hs | 2 | 
15 files changed, 101 insertions, 101 deletions
| @@ -55,10 +55,6 @@ data Expr x env t where    ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)    EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)    EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b -  ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) -  ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) -  ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) -  ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c    -- array operations    EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) @@ -99,6 +95,12 @@ data Expr x env t where    EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t    EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t +  -- interface of abstract monoidal types +  ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) +  ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) +  ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) +  ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c +    -- partiality    EError :: x a -> STy a -> String -> Expr x env a  deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) @@ -376,11 +378,11 @@ class KnownTy t where knownTy :: STy t  instance KnownTy TNil where knownTy = STNil  instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy  instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy  instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy  instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy  instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy  instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy -instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy  class KnownMTy t where knownMTy :: SMTy t  instance KnownMTy TNil where knownMTy = SMTNil @@ -398,11 +400,11 @@ styKnown :: STy t -> Dict (KnownTy t)  styKnown STNil = Dict  styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict  styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict  styKnown (STMaybe t) | Dict <- styKnown t = Dict  styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict  styKnown (STScal t) | Dict <- sscaltyKnown t = Dict  styKnown (STAccum t) | Dict <- smtyKnown t = Dict -styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict  smtyKnown :: SMTy t -> Dict (KnownMTy t)  smtyKnown SMTNil = Dict diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index e84034b..03369c8 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -79,6 +79,7 @@ lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil  lemZeroInfoD2 STNil = Refl  lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl  lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl +lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl  lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl  lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl  lemZeroInfoD2 (STScal STI32) = Refl @@ -87,7 +88,6 @@ lemZeroInfoD2 (STScal STF32) = Refl  lemZeroInfoD2 (STScal STF64) = Refl  lemZeroInfoD2 (STScal STBool) = Refl  lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" -lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl  -- -- | Additional info needed for accumulation. This is empty unless there is  -- -- sparsity in the monoid. diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index e09f3ae..2bb78d4 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -388,6 +388,7 @@ ppSTy' :: Int -> STy t -> Doc q  ppSTy' _ STNil = ppString "1"  ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b  ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b +ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b  ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t  ppSTy' d (STArr n t) = ppParen (d > 10) $    ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t @@ -398,7 +399,6 @@ ppSTy' _ (STScal sty) = ppString $ case sty of    STF64 -> "f64"    STBool -> "bool"  ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t -ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b  ppSMTy :: Int -> SMTy t -> String  ppSMTy d ty = render $ ppSMTy' d ty diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 159934d..1379e35 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -123,11 +123,11 @@ split typ = case typ of    STPair{} -> splitRec (EVar ext typ IZ) typ    STNil -> other    STEither{} -> other +  STLEither{} -> other    STMaybe{} -> other    STArr{} -> other    STScal{} -> other    STAccum{} -> other -  STLEither{} -> other    where      other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])      other = (Point typ IZ, BTop) @@ -142,11 +142,11 @@ splitRec rhs typ = case typ of              (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b          in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2)    STEither{} -> other +  STLEither{} -> other    STMaybe{} -> other    STArr{} -> other    STScal{} -> other    STAccum{} -> other -  STLEither{} -> other    where      other :: (Pointers (t : env) t, Bindings Ex env '[t])      other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/AST/Types.hs b/src/AST/Types.hs index efb1e04..a3b7302 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -23,12 +23,11 @@ type data Ty    = TNil    | TPair Ty Ty    | TEither Ty Ty +  | TLEither Ty Ty    | TMaybe Ty    | TArr Nat Ty  -- ^ rank, element type    | TScal ScalTy    | TAccum Ty  -- ^ contained type must be a monoid type -  -- sparse monoid types -  | TLEither Ty Ty  type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool @@ -37,12 +36,11 @@ data STy t where    STNil :: STy TNil    STPair :: STy a -> STy b -> STy (TPair a b)    STEither :: STy a -> STy b -> STy (TEither a b) +  STLEither :: STy a -> STy b -> STy (TLEither a b)    STMaybe :: STy a -> STy (TMaybe a)    STArr :: SNat n -> STy t -> STy (TArr n t)    STScal :: SScalTy t -> STy (TScal t)    STAccum :: SMTy t -> STy (TAccum t) -  -- sparse monoid types -  STLEither :: STy a -> STy b -> STy (TLEither a b)  deriving instance Show (STy t)  instance GCompare STy where @@ -53,6 +51,8 @@ instance GCompare STy where      STPair{} _ -> GLT ; _ STPair{} -> GGT      (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')      STEither{} _ -> GLT ; _ STEither{} -> GGT +    (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') +    STLEither{} _ -> GLT ; _ STLEither{} -> GGT      (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a')      STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT      (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') @@ -60,9 +60,7 @@ instance GCompare STy where      (STScal t) (STScal t') -> gorderingLift1 (gcompare t t')      STScal{} _ -> GLT ; _ STScal{} -> GGT      (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') -    STAccum{} _ -> GLT ; _ STAccum{} -> GGT -    (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') -    -- STLEither{} _ -> GLT ; _ STLEither{} -> GGT +    -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT  instance TestEquality STy where testEquality = geq  instance GEq STy where geq = defaultGeq @@ -177,11 +175,11 @@ hasArrays :: STy t' -> Bool  hasArrays STNil = False  hasArrays (STPair a b) = hasArrays a || hasArrays b  hasArrays (STEither a b) = hasArrays a || hasArrays b +hasArrays (STLEither a b) = hasArrays a || hasArrays b  hasArrays (STMaybe t) = hasArrays t  hasArrays STArr{} = True  hasArrays STScal{} = False  hasArrays STAccum{} = True -hasArrays (STLEither a b) = hasArrays a || hasArrays b  type family Tup env where    Tup '[] = TNil diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 20575b3..a1a6376 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -28,9 +28,9 @@ data ValId t where    VIPair :: ValId a -> ValId b -> ValId (TPair a b)    VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b)  -- ^ known alternative    VIEither' :: ValId a -> ValId b -> ValId (TEither a b)  -- ^ unknown alternative, but known values in each case +  VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)    VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)    VIMaybe' :: ValId a -> ValId (TMaybe a)  -- ^ if it's Just, it contains this value -  VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)    VIArr :: Int -> Vec n Int -> ValId (TArr n t)    VIScal :: Int -> ValId (TScal t)    VIAccum :: Int -> ValId (TAccum t) @@ -367,11 +367,11 @@ genIds :: STy t -> IdGen (ValId t)  genIds STNil = pure VINil  genIds (STPair a b) = VIPair <$> genIds a <*> genIds b  genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b +genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)  genIds (STMaybe t) = VIMaybe' <$> genIds t  genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId  genIds STScal{} = VIScal <$> genId  genIds STAccum{} = VIAccum <$> genId -genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)  shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int)  shidsToVec SZ _ = pure VNil diff --git a/src/CHAD.hs b/src/CHAD.hs index ac308ac..3a7b907 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -484,6 +484,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of        STNil -> True        STPair a b -> isDiscrete a && isDiscrete b        STEither a b -> isDiscrete a && isDiscrete b +      STLEither a b -> isDiscrete a && isDiscrete b        STMaybe a -> isDiscrete a        STArr _ a -> isDiscrete a        STScal st -> case st of @@ -493,7 +494,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of          STF64 -> False          STBool -> True        STAccum{} -> False -      STLEither a b -> isDiscrete a && isDiscrete b  ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 9e7e7f5..261ddfe 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -49,11 +49,11 @@ d1Identity = \case    STNil -> Refl    STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl    STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl +  STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl    STMaybe t | Refl <- d1Identity t -> Refl    STArr _ t | Refl <- d1Identity t -> Refl    STScal _ -> Refl    STAccum{} -> error "Accumulators not allowed in input program" -  STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl  d1eIdentity :: SList STy env -> D1E env :~: env  d1eIdentity SNil = Refl diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 74e7dbd..974669d 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -11,19 +11,19 @@ type family D1 t where    D1 TNil = TNil    D1 (TPair a b) = TPair (D1 a) (D1 b)    D1 (TEither a b) = TEither (D1 a) (D1 b) +  D1 (TLEither a b) = TLEither (D1 a) (D1 b)    D1 (TMaybe a) = TMaybe (D1 a)    D1 (TArr n t) = TArr n (D1 t)    D1 (TScal t) = TScal t -  D1 (TLEither a b) = TLEither (D1 a) (D1 b)  type family D2 t where    D2 TNil = TNil    D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))    D2 (TEither a b) = TLEither (D2 a) (D2 b) +  D2 (TLEither a b) = TLEither (D2 a) (D2 b)    D2 (TMaybe t) = TMaybe (D2 t)    D2 (TArr n t) = TMaybe (TArr n (D2 t))    D2 (TScal t) = D2s t -  D2 (TLEither a b) = TLEither (D2 a) (D2 b)  type family D2s t where    D2s TI32 = TNil @@ -48,11 +48,11 @@ d1 :: STy t -> STy (D1 t)  d1 STNil = STNil  d1 (STPair a b) = STPair (d1 a) (d1 b)  d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STLEither a b) = STLEither (d1 a) (d1 b)  d1 (STMaybe t) = STMaybe (d1 t)  d1 (STArr n t) = STArr n (d1 t)  d1 (STScal t) = STScal t  d1 STAccum{} = error "Accumulators not allowed in input program" -d1 (STLEither a b) = STLEither (d1 a) (d1 b)  d1e :: SList STy env -> SList STy (D1E env)  d1e SNil = SNil @@ -62,6 +62,7 @@ d2M :: STy t -> SMTy (D2 t)  d2M STNil = SMTNil  d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b))  d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)  d2M (STMaybe t) = SMTMaybe (d2M t)  d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t))  d2M (STScal t) = case t of @@ -71,7 +72,6 @@ d2M (STScal t) = case t of    STF64 -> SMTScal STF64    STBool -> SMTNil  d2M STAccum{} = error "Accumulators not allowed in input program" -d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)  d2 :: STy t -> STy (D2 t)  d2 = fromSMTy . d2M diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index 87c01cb..8476712 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -28,6 +28,11 @@ toTan typ primal der = case typ of                          (Left p, Left d') -> Left (toTan t1 p d')                          (Right p, Right d') -> Right (toTan t2 p d')                          _ -> error "Primal and cotangent disagree on Either alternative" +  STLEither t1 t2 -> case (primal, der) of +                       (_, Nothing) -> Nothing +                       (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) +                       (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) +                       _ -> error "Primal and cotangent disagree on LEither alternative"    STMaybe t -> liftA2 (toTan t) primal der    STArr _ t -> case der of                   Nothing -> arrayMap (zeroTan t) primal @@ -40,8 +45,3 @@ toTan typ primal der = case typ of    STScal sty -> case sty of      STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der    STAccum{} -> error "Accumulators not allowed in input program" -  STLEither t1 t2 -> case (primal, der) of -                       (_, Nothing) -> Nothing -                       (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) -                       (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) -                       _ -> error "Primal and cotangent disagree on LEither alternative" diff --git a/src/Compile.hs b/src/Compile.hs index 6ba3a39..cd10831 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -221,6 +221,7 @@ genStructName = \t -> "ty_" ++ gen t where    gen STNil = "n"    gen (STPair a b) = 'P' : gen a ++ gen b    gen (STEither a b) = 'E' : gen a ++ gen b +  gen (STLEither a b) = 'L' : gen a ++ gen b    gen (STMaybe t) = 'M' : gen t    gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t    gen (STScal st) = case st of @@ -230,7 +231,6 @@ genStructName = \t -> "ty_" ++ gen t where      STF64 -> "d"      STBool -> "b"    gen (STAccum t) = 'C' : gen (fromSMTy t) -  gen (STLEither a b) = 'L' : gen a ++ gen b  -- | This function generates the actual struct declarations for each of the  -- types in our language. It thus implicitly "documents" the layout of the @@ -246,6 +246,8 @@ genStruct name topty = case topty of      [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]    STEither a b ->  -- 0 -> l, 1 -> r      [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] +  STLEither a b ->  -- 0 -> nil, 1 -> l, 2 -> r +    [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]    STMaybe t ->  -- 0 -> nothing, 1 -> just      [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]    STArr n t -> @@ -259,8 +261,6 @@ genStruct name topty = case topty of    STAccum t ->      [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") ""      ,StructDecl name (name ++ "_buf *buf;") com] -  STLEither a b ->  -- 0 -> nil, 1 -> l, 2 -> r -    [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]    where      com = ppSTy 0 topty @@ -282,11 +282,11 @@ genStructs ty = do          STNil -> pure ()          STPair a b -> genStructs a >> genStructs b          STEither a b -> genStructs a >> genStructs b +        STLEither a b -> genStructs a >> genStructs b          STMaybe t -> genStructs t          STArr _ t -> genStructs t          STScal _ -> pure ()          STAccum t -> genStructs (fromSMTy t) -        STLEither a b -> genStructs a >> genStructs b        tell (BList (genStruct name ty)) @@ -463,6 +463,15 @@ serialise topty topval ptr off k =      (STEither _ b, Right y) -> do        pokeByteOff ptr off (1 :: Word8)        serialise b y ptr (off + alignmentSTy topty) k +    (STLEither _ _, Nothing) -> do +      pokeByteOff ptr off (0 :: Word8) +      k +    (STLEither a _, Just (Left x)) -> do +      pokeByteOff ptr off (1 :: Word8)  -- alignment of (union {a b}) is the same as alignment of (1 + a + b) +      serialise a x ptr (off + alignmentSTy topty) k +    (STLEither _ b, Just (Right y)) -> do +      pokeByteOff ptr off (2 :: Word8) +      serialise b y ptr (off + alignmentSTy topty) k      (STMaybe _, Nothing) -> do        pokeByteOff ptr off (0 :: Word8)        k @@ -493,15 +502,6 @@ serialise topty topval ptr off k =        STF64 -> pokeByteOff ptr off (x :: Double) >> k        STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k      (STAccum{}, _) -> error "Cannot serialise accumulators" -    (STLEither _ _, Nothing) -> do -      pokeByteOff ptr off (0 :: Word8) -      k -    (STLEither a _, Just (Left x)) -> do -      pokeByteOff ptr off (1 :: Word8)  -- alignment of (union {a b}) is the same as alignment of (1 + a + b) -      serialise a x ptr (off + alignmentSTy topty) k -    (STLEither _ b, Just (Right y)) -> do -      pokeByteOff ptr off (2 :: Word8) -      serialise b y ptr (off + alignmentSTy topty) k  -- | Assumes that this is called at the correct alignment.  deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) @@ -518,6 +518,13 @@ deserialise topty ptr off =        if tag == 0  -- alignment of (union {a b}) is the same as alignment of (a + b)          then Left <$> deserialise a ptr (off + alignmentSTy topty)          else Right <$> deserialise b ptr (off + alignmentSTy topty) +    STLEither a b -> do +      tag <- peekByteOff @Word8 ptr off +      case tag of  -- alignment of (union {a b}) is the same as alignment of (a + b) +        0 -> return Nothing +        1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) +        2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) +        _ -> error "Invalid tag value"      STMaybe t -> do        tag <- peekByteOff @Word8 ptr off        if tag == 0 @@ -541,13 +548,6 @@ deserialise topty ptr off =        STF64 -> peekByteOff @Double ptr off        STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off      STAccum{} -> error "Cannot serialise accumulators" -    STLEither a b -> do -      tag <- peekByteOff @Word8 ptr off -      case tag of  -- alignment of (union {a b}) is the same as alignment of (a + b) -        0 -> return Nothing -        1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) -        2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) -        _ -> error "Invalid tag value"  align :: Int -> Int -> Int  align a off = (off + a - 1) `div` a * a @@ -569,6 +569,10 @@ metricsSTy (STEither a b) =    let (a1, s1) = metricsSTy a        (a2, s2) = metricsSTy b    in (max a1 a2, max a1 a2 + max s1 s2)  -- the union after the tag byte is aligned +metricsSTy (STLEither a b) = +  let (a1, s1) = metricsSTy a +      (a2, s2) = metricsSTy b +  in (max a1 a2, max a1 a2 + max s1 s2)  -- the union after the tag byte is aligned  metricsSTy (STMaybe t) =    let (a, s) = metricsSTy t    in (a, a + s)  -- the union after the tag byte is aligned @@ -580,10 +584,6 @@ metricsSTy (STScal sty) = case sty of    STF64 -> (8, 8)    STBool -> (1, 1)  -- compiled to uint8_t  metricsSTy (STAccum t) = metricsSTy (fromSMTy t) -metricsSTy (STLEither a b) = -  let (a1, s1) = metricsSTy a -      (a2, s2) = metricsSTy b -  in (max a1 a2, max a1 a2 + max s1 s2)  -- the union after the tag byte is aligned  pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()  pokeShape ptr off = go . fromSNat @@ -1071,8 +1071,8 @@ compile' env = \case                                              incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend                                              emit $ SAsg v (CELit addend)            -- sparse types -          (SMTLEither{} , _          ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")            (SMTMaybe{}   , _          ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") +          (SMTLEither{} , _          ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")            -- dense types            (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do                                              f (v++".a") (i++".a") @@ -1303,13 +1303,13 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a))                                           (smartATProj "b" (makeArrayTree b))  makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a))                                                (smartATProj "r" (makeArrayTree b)) +makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop +                                                (smartATProj "l" (makeArrayTree a)) +                                                (smartATProj "r" (makeArrayTree b))  makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t))  makeArrayTree (STArr n t) = ATArray (Some n) (Some t)  makeArrayTree (STScal _) = ATNoop  makeArrayTree (STAccum _) = ATNoop -makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop -                                                (smartATProj "l" (makeArrayTree a)) -                                                (smartATProj "r" (makeArrayTree b))  incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM ()  incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = @@ -1657,6 +1657,14 @@ zeroRefcountCheck toptyp opname topvar =      go (STEither a b) path = do        (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))        return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 +    go (STLEither a b) path = do +      (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) +      return $ pure $ +        SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) +          s1 +          (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) +                   s2 +                   mempty))      go (STMaybe a) path = do        ss <- go a (path++".j")        return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty @@ -1673,14 +1681,6 @@ zeroRefcountCheck toptyp opname topvar =        return (BList [s1, s2, s3])      go STScal{} _ = empty      go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" -    go (STLEither a b) path = do -      (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) -      return $ pure $ -        SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) -          s1 -          (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) -                   s2 -                   mempty))      combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)      combine (MaybeT a) (MaybeT b) = MaybeT $ do diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 5756f96..b353def 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -26,10 +26,10 @@ type family Tan t where    Tan TNil = TNil    Tan (TPair a b) = TPair (Tan a) (Tan b)    Tan (TEither a b) = TEither (Tan a) (Tan b) +  Tan (TLEither a b) = TLEither (Tan a) (Tan b)    Tan (TMaybe t) = TMaybe (Tan t)    Tan (TArr n t) = TArr n (Tan t)    Tan (TScal t) = TanS t -  Tan (TLEither a b) = TLEither (Tan a) (Tan b)  type family TanS t where    TanS TI32 = TNil @@ -46,6 +46,7 @@ tanty :: STy t -> STy (Tan t)  tanty STNil = STNil  tanty (STPair a b) = STPair (tanty a) (tanty b)  tanty (STEither a b) = STEither (tanty a) (tanty b) +tanty (STLEither a b) = STLEither (tanty a) (tanty b)  tanty (STMaybe t) = STMaybe (tanty t)  tanty (STArr n t) = STArr n (tanty t)  tanty (STScal t) = case t of @@ -55,7 +56,6 @@ tanty (STScal t) = case t of    STF64 -> STScal STF64    STBool -> STNil  tanty STAccum{} = error "Accumulators not allowed in input program" -tanty (STLEither a b) = STLEither (tanty a) (tanty b)  tanenv :: SList STy env -> SList STy (TanE env)  tanenv SNil = SNil @@ -66,6 +66,9 @@ zeroTan STNil () = ()  zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)  zeroTan (STEither a _) (Left x) = Left (zeroTan a x)  zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))  zeroTan (STMaybe _) Nothing = Nothing  zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)  zeroTan (STArr _ t) x = fmap (zeroTan t) x @@ -75,15 +78,15 @@ zeroTan (STScal STF32) _ = 0.0  zeroTan (STScal STF64) _ = 0.0  zeroTan (STScal STBool) _ = ()  zeroTan STAccum{} _ = error "Accumulators not allowed in input program" -zeroTan (STLEither _ _) Nothing = Nothing -zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) -zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))  tanScalars :: STy t -> Rep (Tan t) -> [Double]  tanScalars STNil () = []  tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y  tanScalars (STEither a _) (Left x) = tanScalars a x  tanScalars (STEither _ b) (Right y) = tanScalars b y +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y  tanScalars (STMaybe _) Nothing = []  tanScalars (STMaybe t) (Just x) = tanScalars t x  tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x @@ -93,9 +96,6 @@ tanScalars (STScal STF32) x = [realToFrac x]  tanScalars (STScal STF64) x = [x]  tanScalars (STScal STBool) _ = []  tanScalars STAccum{} _ = error "Accumulators not allowed in input program" -tanScalars (STLEither _ _) Nothing = [] -tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x -tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y  tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]  tanEScalars SNil SNil = [] @@ -110,6 +110,10 @@ unzipDN (STPair a b) (d1, d2) =  unzipDN (STEither a b) d = case d of    Left d1 -> bimap Left Left (unzipDN a d1)    Right d2 -> bimap Right Right (unzipDN b d2) +unzipDN (STLEither a b) d = case d of +  Nothing -> (Nothing, Nothing) +  Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) +  Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)  unzipDN (STMaybe t) d = case d of    Nothing -> (Nothing, Nothing)    Just d' -> bimap Just Just (unzipDN t d') @@ -123,10 +127,6 @@ unzipDN (STScal ty) d = case ty of    STF64 -> d    STBool -> (d, ())  unzipDN STAccum{} _ = error "Accumulators not allowed in input program" -unzipDN (STLEither a b) d = case d of -  Nothing -> (Nothing, Nothing) -  Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) -  Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)  dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double  dotprodTan STNil _ _ = 0.0 @@ -136,6 +136,12 @@ dotprodTan (STEither a b) x y = case (x, y) of    (Left x', Left y') -> dotprodTan a x' y'    (Right x', Right y') -> dotprodTan b x' y'    _ -> error "dotprodTan: incompatible Either alternatives" +dotprodTan (STLEither a b) x y = case (x, y) of +  (Nothing, _) -> 0.0  -- 0 * y = 0 +  (_, Nothing) -> 0.0  -- x * 0 = 0 +  (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' +  (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' +  _ -> error "dotprodTan: incompatible LEither alternatives"  dotprodTan (STMaybe t) x y = case (x, y) of    (Nothing, Nothing) -> 0.0    (Just x', Just y') -> dotprodTan t x' y' @@ -153,12 +159,6 @@ dotprodTan (STScal ty) x y = case ty of    STF64 -> x * y    STBool -> 0.0  dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" -dotprodTan (STLEither a b) x y = case (x, y) of -  (Nothing, _) -> 0.0  -- 0 * y = 0 -  (_, Nothing) -> 0.0  -- x * 0 = 0 -  (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' -  (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' -  _ -> error "dotprodTan: incompatible LEither alternatives"  -- -- Primal expression must be duplicable  -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -187,6 +187,7 @@ dnConst :: STy t -> Rep t -> Rep (DN t)  dnConst STNil = const ()  dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)  dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))  dnConst (STMaybe t) = fmap (dnConst t)  dnConst (STArr _ t) = arrayMap (dnConst t)  dnConst (STScal t) = case t of @@ -196,7 +197,6 @@ dnConst (STScal t) = case t of    STF64 -> (,0.0)    STBool -> id  dnConst STAccum{} = error "Accumulators not allowed in input program" -dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))  -- | Given a function that computes the forward derivative for a particular  -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -211,6 +211,11 @@ dnOnehots (STEither t1 t2) e =    case e of      Left x -> \f -> Left (dnOnehots t1 x (f . Left))      Right y -> \f -> Right (dnOnehots t2 y (f . Right)) +dnOnehots (STLEither t1 t2) e = +  case e of +    Nothing -> \_ -> Nothing +    Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) +    Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))  dnOnehots (STMaybe t) m =    case m of      Nothing -> \_ -> Nothing @@ -227,11 +232,6 @@ dnOnehots (STScal t) x = case t of    STF64 -> \f -> f (x, 1.0)    STBool -> \_ -> ()  dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" -dnOnehots (STLEither t1 t2) e = -  case e of -    Nothing -> \_ -> Nothing -    Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) -    Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))  dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)  dnConstEnv SNil SNil = SNil diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs index 3c76cbe..dcacf5f 100644 --- a/src/ForwardAD/DualNumbers/Types.hs +++ b/src/ForwardAD/DualNumbers/Types.hs @@ -12,10 +12,10 @@ type family DN t where    DN TNil = TNil    DN (TPair a b) = TPair (DN a) (DN b)    DN (TEither a b) = TEither (DN a) (DN b) +  DN (TLEither a b) = TLEither (DN a) (DN b)    DN (TMaybe t) = TMaybe (DN t)    DN (TArr n t) = TArr n (DN t)    DN (TScal t) = DNS t -  DN (TLEither a b) = TLEither (DN a) (DN b)  type family DNS t where    DNS TF32 = TPair (TScal TF32) (TScal TF32) @@ -32,6 +32,7 @@ dn :: STy t -> STy (DN t)  dn STNil = STNil  dn (STPair a b) = STPair (dn a) (dn b)  dn (STEither a b) = STEither (dn a) (dn b) +dn (STLEither a b) = STLEither (dn a) (dn b)  dn (STMaybe t) = STMaybe (dn t)  dn (STArr n t) = STArr n (dn t)  dn (STScal t) = case t of @@ -41,7 +42,6 @@ dn (STScal t) = case t of    STI64 -> STScal STI64    STBool -> STScal STBool  dn STAccum{} = error "Accum in source program" -dn (STLEither a b) = STLEither (dn a) (dn b)  dne :: SList STy env -> SList STy (DNE env)  dne SNil = SNil diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 070ba4c..1682303 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -22,11 +22,11 @@ type family Rep t where    Rep TNil = ()    Rep (TPair a b) = (Rep a, Rep b)    Rep (TEither a b) = Either (Rep a) (Rep b) +  Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b))    Rep (TMaybe t) = Maybe (Rep t)    Rep (TArr n t) = Array n (Rep t)    Rep (TScal sty) = ScalRep sty    Rep (TAccum t) = RepAc t -  Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b))  -- Mutable, represents monoid types t.  type family RepAc t where @@ -56,6 +56,9 @@ showValue _ STNil () = showString "()"  showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"  showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x  showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y +showValue _ (STLEither _ _) Nothing = showString "LNil" +showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x +showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y  showValue _ (STMaybe _) Nothing = showString "Nothing"  showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x  showValue d (STArr _ t) arr = showParen (d > 10) $ @@ -70,9 +73,6 @@ showValue d (STScal sty) x = case sty of    STI64 -> showsPrec d x    STBool -> showsPrec d x  showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">" -showValue _ (STLEither _ _) Nothing = showString "LNil" -showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x -showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y  showEnv :: SList STy env -> SList Value env -> String  showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" @@ -86,6 +86,9 @@ rnfRep STNil () = ()  rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y  rnfRep (STEither a _) (Left x) = rnfRep a x  rnfRep (STEither _ b) (Right y) = rnfRep b y +rnfRep (STLEither _ _) Nothing = () +rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x +rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y  rnfRep (STMaybe _) Nothing = ()  rnfRep (STMaybe t) (Just x) = rnfRep t x  rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr = @@ -97,9 +100,6 @@ rnfRep (STScal t) x = case t of    STF64 -> rnf x    STBool -> rnf x  rnfRep STAccum{} _ = error "Cannot rnf accumulators" -rnfRep (STLEither _ _) Nothing = () -rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x -rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y  instance KnownTy t => NFData (Value t) where    rnf (Value x) = rnfRep (knownTy @t) x diff --git a/src/Simplify.hs b/src/Simplify.hs index f5eb0a1..6f97e6d 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -359,11 +359,11 @@ checkAccumInScope = \case SNil -> False      check STNil = False      check (STPair s t) = check s || check t      check (STEither s t) = check s || check t +    check (STLEither s t) = check s || check t      check (STMaybe t) = check t      check (STArr _ t) = check t      check (STScal _) = False      check STAccum{} = True -    check (STLEither s t) = check s || check t  data OneHotTerm env p a b where    OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b | 
