diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/AST.hs | 84 | ||||
-rw-r--r-- | src/AST/Accum.hs | 90 | ||||
-rw-r--r-- | src/AST/Count.hs | 6 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 76 | ||||
-rw-r--r-- | src/AST/SplitLets.hs | 26 | ||||
-rw-r--r-- | src/AST/Types.hs | 51 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 145 | ||||
-rw-r--r-- | src/Analysis/Identity.hs | 59 | ||||
-rw-r--r-- | src/CHAD.hs | 83 | ||||
-rw-r--r-- | src/CHAD/Accum.hs | 4 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 3 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 44 | ||||
-rw-r--r-- | src/Compile.hs | 441 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 4 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers/Types.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 25 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 11 | ||||
-rw-r--r-- | src/Language.hs | 8 | ||||
-rw-r--r-- | src/Language/AST.hs | 4 | ||||
-rw-r--r-- | src/Simplify.hs | 138 |
20 files changed, 869 insertions, 435 deletions
@@ -16,6 +16,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE FlexibleInstances #-} module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where import Data.Functor.Const @@ -33,11 +34,9 @@ import Data -- inner variable / inner array dimension. In pretty printing, the inner -- variable / inner dimension is printed on the _right_. -- --- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the --- type transformation of CHAD. Indeed, these constructors are created _by_ --- CHAD, and are intended to be eliminated after simplification, so that the --- input program as well as the output program do not contain these --- constructors. +-- All the monoid operations are unsupposed as the input to CHAD, and are +-- intended to be eliminated after simplification, so that the input program as +-- well as the output program do not contain these constructors. -- TODO: ensure this by a "stage" type parameter. type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type data Expr x env t where @@ -56,6 +55,10 @@ data Expr x env t where ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b + ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) + ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) + ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) + ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) @@ -88,13 +91,13 @@ data Expr x env t where -> Expr x env t -- accumulation effect on monoids - EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t)) - EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil + EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) - EZero :: x (D2 t) -> STy t -> Expr x env (D2 t) - EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) - EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t) + EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t + EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t -- partiality EError :: x a -> STy a -> String -> Expr x env a @@ -184,6 +187,10 @@ typeOf = \case ENothing _ t -> STMaybe t EJust _ e -> STMaybe (typeOf e) EMaybe _ e _ _ -> typeOf e + ELNil _ t1 t2 -> STLEither t1 t2 + ELInl _ t2 e -> STLEither (typeOf e) t2 + ELInr _ t1 e -> STLEither t1 (typeOf e) + ELCase _ _ a _ _ -> typeOf a EConstArr _ n t _ -> STArr n (STScal t) EBuild _ n _ e -> STArr n (typeOf e) @@ -206,9 +213,9 @@ typeOf = \case EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) EAccum _ _ _ _ _ _ -> STNil - EZero _ t -> d2 t - EPlus _ t _ _ -> d2 t - EOneHot _ t _ _ _ -> d2 t + EZero _ t _ -> fromSMTy t + EPlus _ t _ _ -> fromSMTy t + EOneHot _ t _ _ _ -> fromSMTy t EError _ t _ -> t @@ -226,6 +233,10 @@ extOf = \case ENothing x _ -> x EJust x _ -> x EMaybe x _ _ _ -> x + ELNil x _ _ -> x + ELInl x _ _ -> x + ELInr x _ _ -> x + ELCase x _ _ _ _ -> x EConstArr x _ _ _ -> x EBuild x _ _ _ -> x EFold1Inner x _ _ _ _ -> x @@ -243,7 +254,7 @@ extOf = \case ECustom x _ _ _ _ _ _ _ _ -> x EWith x _ _ _ -> x EAccum x _ _ _ _ _ -> x - EZero x _ -> x + EZero x _ _ -> x EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x @@ -262,6 +273,10 @@ mapExt f = \case ENothing x t -> ENothing (f x) t EJust x e -> EJust (f x) (mapExt f e) EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e) + ELNil x t1 t2 -> ELNil (f x) t1 t2 + ELInl x t e -> ELInl (f x) t (mapExt f e) + ELInr x t e -> ELInr (f x) t (mapExt f e) + ELCase x e a b c -> ELCase (f x) (mapExt f e) (mapExt f a) (mapExt f b) (mapExt f c) EConstArr x n t a -> EConstArr (f x) n t a EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b) EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c) @@ -279,7 +294,7 @@ mapExt f = \case ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2) EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2) EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3) - EZero x t -> EZero (f x) t + EZero x t e -> EZero (f x) t (mapExt f e) EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b) EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b) EError x t s -> EError (f x) t s @@ -315,6 +330,10 @@ subst' f w = \case ENothing x t -> ENothing x t EJust x e -> EJust x (subst' f w e) EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) + ELNil x t1 t2 -> ELNil x t1 t2 + ELInl x t e -> ELInl x t (subst' f w e) + ELInr x t e -> ELInr x t (subst' f w e) + ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) @@ -332,9 +351,9 @@ subst' f w = \case ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) - EZero x t -> EZero x t + EZero x t e -> EZero x t (subst' f w e) EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) - EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) + EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -360,7 +379,16 @@ instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEithe instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy +instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy +instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy + +class KnownMTy t where knownMTy :: SMTy t +instance KnownMTy TNil where knownMTy = SMTNil +instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy +instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy +instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy +instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy +instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy class KnownEnv env where knownEnv :: SList STy env instance KnownEnv '[] where knownEnv = SNil @@ -373,7 +401,16 @@ styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict styKnown (STMaybe t) | Dict <- styKnown t = Dict styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict styKnown (STScal t) | Dict <- sscaltyKnown t = Dict -styKnown (STAccum t) | Dict <- styKnown t = Dict +styKnown (STAccum t) | Dict <- smtyKnown t = Dict +styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict + +smtyKnown :: SMTy t -> Dict (KnownMTy t) +smtyKnown SMTNil = Dict +smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict +smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict +smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) sscaltyKnown STI32 = Dict @@ -451,3 +488,12 @@ eshapeEmpty (SS n) e = (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) (EConst ext STI64 0))) (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) + +ezeroD2 :: STy t -> Ex env (D2 t) +ezeroD2 t | Refl <- lemZeroInfoD2 t = EZero ext (d2M t) (ENil ext) + +-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil +-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea + +-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) +-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 67c5de7..e84034b 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -8,6 +8,7 @@ module AST.Accum where import AST.Types +import CHAD.Types import Data @@ -26,35 +27,90 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where SAPHere :: SAcPrj APHere a a SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b - SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b - SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b + SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b + SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b - -- TODO: This SNat is rather useless, you always have an STy around too - SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b + SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b -- TODO: -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) type family AcIdx p t where AcIdx APHere t = TNil - AcIdx (APFst p) (TPair a b) = AcIdx p a - AcIdx (APSnd p) (TPair a b) = AcIdx p b - AcIdx (APLeft p) (TEither a b) = AcIdx p a - AcIdx (APRight p) (TEither a b) = AcIdx p b + AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b) + AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b) + AcIdx (APLeft p) (TLEither a b) = AcIdx p a + AcIdx (APRight p) (TLEither a b) = AcIdx p b AcIdx (APJust p) (TMaybe a) = AcIdx p a AcIdx (APArrIdx p) (TArr n a) = - -- ((index, array shape), recursive info) - TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx))) + -- ((index, shapes info), recursive info) + TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) (AcIdx p a) -- AcIdx (APArrSlice m) (TArr n a) = -- -- (index, array shape) -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) -acPrjTy :: SAcPrj p a b -> STy a -> STy b +acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b acPrjTy SAPHere t = t -acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t -acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t -acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t -acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t -acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t -acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t +acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t +acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t +acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t +acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t +acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t +acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t + +type family ZeroInfo t where + ZeroInfo TNil = TNil + ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) + ZeroInfo (TLEither a b) = TNil + ZeroInfo (TMaybe a) = TNil + ZeroInfo (TArr n t) = TArr n (ZeroInfo t) + ZeroInfo (TScal t) = TNil + +tZeroInfo :: SMTy t -> STy (ZeroInfo t) +tZeroInfo SMTNil = STNil +tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) +tZeroInfo (SMTLEither _ _) = STNil +tZeroInfo (SMTMaybe _) = STNil +tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) +tZeroInfo (SMTScal _) = STNil + +lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil +lemZeroInfoD2 STNil = Refl +lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl +lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl +lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl +lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl +lemZeroInfoD2 (STScal STI32) = Refl +lemZeroInfoD2 (STScal STI64) = Refl +lemZeroInfoD2 (STScal STF32) = Refl +lemZeroInfoD2 (STScal STF64) = Refl +lemZeroInfoD2 (STScal STBool) = Refl +lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" +lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl + +-- -- | Additional info needed for accumulation. This is empty unless there is +-- -- sparsity in the monoid. +-- type family AccumInfo t where +-- AccumInfo TNil = TNil +-- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b) +-- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- AccumInfo (TMaybe a) = TMaybe (AccumInfo a) +-- AccumInfo (TArr n t) = TArr n (AccumInfo t) +-- AccumInfo (TScal t) = TNil + +-- type family PrimalInfo t where +-- PrimalInfo TNil = TNil +-- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a) +-- PrimalInfo (TArr n t) = TArr n (PrimalInfo t) +-- PrimalInfo (TScal t) = TNil + +-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t) +-- tPrimalInfo SMTNil = STNil +-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a) +-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t) +-- tPrimalInfo (SMTScal _) = STNil diff --git a/src/AST/Count.hs b/src/AST/Count.hs index dc8ec72..feaaa1e 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -113,6 +113,10 @@ occCountGeneral onehot unpush alter many = go WId ENothing _ _ -> mempty EJust _ e -> re e EMaybe _ a b e -> re a <> re1 b <> re e + ELNil _ _ _ -> mempty + ELInl _ _ e -> re e + ELInr _ _ e -> re e + ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c) EConstArr{} -> mempty EBuild _ _ a b -> re a <> many (re1 b) EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c @@ -130,7 +134,7 @@ occCountGeneral onehot unpush alter many = go WId ECustom _ _ _ _ _ _ _ a b -> re a <> re b EWith _ _ a b -> re a <> re1 b EAccum _ _ _ a b e -> re a <> re b <> re e - EZero _ _ -> mempty + EZero _ _ e -> re e EPlus _ _ a b -> re a <> re b EOneHot _ _ _ a b -> re a <> re b EError{} -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index fb5e138..b6ad7d2 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -7,7 +7,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} -module AST.Pretty (pprintExpr, ppExpr, ppSTy, PrettyX(..)) where +module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where import Control.Monad (ap) import Data.List (intersperse, intercalate) @@ -152,6 +152,31 @@ ppExpr' d val expr = case expr of return $ ppParen (d > 10) $ ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e'] + ELNil _ _ _ -> return (ppString "LNil") + + ELInl _ _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e' + + ELInr _ _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e' + + ELCase _ e a b c -> do + e' <- ppExpr' 0 val e + let STLEither t1 t2 = typeOf e + a' <- ppExpr' 11 val a + name1 <- genNameIfUsedIn t1 IZ b + b' <- ppExpr' 0 (Const name1 `SCons` val) b + name2 <- genNameIfUsedIn t2 IZ c + c' <- ppExpr' 0 (Const name2 `SCons` val) c + return $ ppParen (d > 0) $ + hang 2 $ + annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of") + <> hardline <> ppString "LNil" <+> ppString "->" <+> a' + <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b' + <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c' + EConstArr _ _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -267,15 +292,17 @@ ppExpr' d val expr = case expr of <> hardline <> e2') (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - EAccum _ _ prj e1 e2 e3 -> do + EAccum _ t prj e1 e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3'] + ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj t prj), e1', e2', e3'] - EZero _ t -> return $ ppParen (d > 0) $ - annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t + EZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' EPlus _ _ a b -> do a' <- ppExpr' 11 val a @@ -283,11 +310,11 @@ ppExpr' d val expr = case expr of return $ ppParen (d > 10) $ ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b'] - EOneHot _ _ prj a b -> do + EOneHot _ t prj a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b'] + ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj t prj), a', b'] EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) @@ -320,14 +347,14 @@ ppLam :: [ADoc] -> ADoc -> ADoc ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) <> softline <> body <> ppString ")") -ppAcPrj :: SAcPrj p a b -> String -ppAcPrj SAPHere = "@" -ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)" -ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")" -ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)" -ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")" -ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj -ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n) +ppAcPrj :: SMTy a -> SAcPrj p a b -> String +ppAcPrj _ SAPHere = "@" +ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)" +ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)" +ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj +ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) @@ -370,7 +397,24 @@ ppSTy' _ (STScal sty) = ppString $ case sty of STF32 -> "f32" STF64 -> "f64" STBool -> "bool" -ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSTy' 11 t +ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t +ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b + +ppSMTy :: Int -> SMTy t -> String +ppSMTy d ty = render $ ppSMTy' d ty + +ppSMTy' :: Int -> SMTy t -> Doc q +ppSMTy' _ SMTNil = ppString "1" +ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b +ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b +ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t +ppSMTy' d (SMTArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t +ppSMTy' _ (SMTScal sty) = ppString $ case sty of + STI32 -> "i32" + STI64 -> "i64" + STF32 -> "f32" + STF64 -> "f64" ppString :: String -> Doc x ppString = fromString diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index dcba1ad..159934d 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -29,6 +29,9 @@ splitLets' = \sub -> \case EMaybe x a b e -> let STMaybe t1 = typeOf e in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e) + ELCase x e a b c -> + let STLEither t1 t2 = typeOf e + in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) @@ -41,6 +44,9 @@ splitLets' = \sub -> \case EInr x t e -> EInr x t (splitLets' sub e) ENothing x t -> ENothing x t EJust x e -> EJust x (splitLets' sub e) + ELNil x t1 t2 -> ELNil x t1 t2 + ELInl x t e -> ELInl x t (splitLets' sub e) + ELInr x t e -> ELInr x t (splitLets' sub e) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) ESum1Inner x e -> ESum1Inner x (splitLets' sub e) @@ -57,7 +63,7 @@ splitLets' = \sub -> \case ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) - EZero x t -> EZero x t + EZero x t ezi -> EZero x t (splitLets' sub ezi) EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) EError x t s -> EError x t s @@ -121,24 +127,26 @@ split typ = case typ of STArr{} -> other STScal{} -> other STAccum{} -> other + STLEither{} -> other where other :: (Pointers (t : env) t, Bindings Ex (t : env) '[]) other = (Point typ IZ, BTop) splitRec :: forall env t. Ex env t -> STy t -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t)) -splitRec rhs = \case +splitRec rhs typ = case typ of STNil -> (PNil, BTop) STPair (a :: STy a) (b :: STy b) | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env -> let (p1, bs1) = splitRec (EFst ext rhs) a (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) - t@STEither{} -> other t - t@STMaybe{} -> other t - t@STArr{} -> other t - t@STScal{} -> other t - t@STAccum{} -> other t + STEither{} -> other + STMaybe{} -> other + STArr{} -> other + STScal{} -> other + STAccum{} -> other + STLEither{} -> other where - other :: STy t -> (Pointers (t : env) t, Bindings Ex env '[t]) - other t = (Point t IZ, BPush BTop (t, rhs)) + other :: (Pointers (t : env) t, Bindings Ex env '[t]) + other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/AST/Types.hs b/src/AST/Types.hs index b20fc2d..c8515fc 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -27,6 +27,8 @@ type data Ty | TArr Nat Ty -- ^ rank, element type | TScal ScalTy | TAccum Ty -- ^ contained type must be a monoid type + -- sparse monoid types + | TLEither Ty Ty type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool @@ -38,7 +40,9 @@ data STy t where STMaybe :: STy a -> STy (TMaybe a) STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) - STAccum :: STy t -> STy (TAccum t) + STAccum :: SMTy t -> STy (TAccum t) + -- sparse monoid types + STLEither :: STy a -> STy b -> STy (TLEither a b) deriving instance Show (STy t) instance GCompare STy where @@ -56,12 +60,54 @@ instance GCompare STy where (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') STScal{} _ -> GLT ; _ STScal{} -> GGT (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') - -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT + STAccum{} _ -> GLT ; _ STAccum{} -> GGT + (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + -- STLEither{} _ -> GLT ; _ STLEither{} -> GGT instance TestEquality STy where testEquality = geq instance GEq STy where geq = defaultGeq instance GShow STy where gshowsPrec = defaultGshowsPrec +-- | Monoid types +type SMTy :: Ty -> Type +data SMTy t where + SMTNil :: SMTy TNil + SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b) + -- TODO: call this SMTLEither + SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b) + SMTMaybe :: SMTy a -> SMTy (TMaybe a) + SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) + SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) +deriving instance Show (SMTy t) + +instance GCompare SMTy where + gcompare = \cases + SMTNil SMTNil -> GEQ + SMTNil _ -> GLT ; _ SMTNil -> GGT + (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT + (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT + (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a') + SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT + (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') + SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT + (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t') + -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + +instance TestEquality SMTy where testEquality = geq +instance GEq SMTy where geq = defaultGeq +instance GShow SMTy where gshowsPrec = defaultGshowsPrec + +fromSMTy :: SMTy t -> STy t +fromSMTy = \case + SMTNil -> STNil + SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2) + SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2) + SMTMaybe t -> STMaybe (fromSMTy t) + SMTArr n t -> STArr n (fromSMTy t) + SMTScal sty -> STScal sty + data SScalTy t where STI32 :: SScalTy TI32 STI64 :: SScalTy TI64 @@ -136,6 +182,7 @@ hasArrays (STMaybe t) = hasArrays t hasArrays STArr{} = True hasArrays STScal{} = False hasArrays STAccum{} = True +hasArrays (STLEither a b) = hasArrays a || hasArrays b type family Tup env where Tup '[] = TNil diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 0da1afc..3d5f544 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -5,13 +5,14 @@ module AST.UnMonoid (unMonoid, zero, plus) where import AST -import CHAD.Types import Data +-- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them +-- into their concrete implementations. unMonoid :: Ex env t -> Ex env t unMonoid = \case - EZero _ t -> zero t + EZero _ t e -> zero t e EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) @@ -27,6 +28,10 @@ unMonoid = \case ENothing _ t -> ENothing ext t EJust _ e -> EJust ext (unMonoid e) EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) + ELNil _ t1 t2 -> ELNil ext t1 t2 + ELInl _ t e -> ELInl ext t (unMonoid e) + ELInr _ t e -> ELInr ext t (unMonoid e) + ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) EConstArr _ n t x -> EConstArr ext n t x EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) @@ -46,92 +51,94 @@ unMonoid = \case EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) EError _ t s -> EError ext t s -zero :: STy t -> Ex env (D2 t) -zero STNil = ENil ext -zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2)) -zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2)) -zero (STMaybe t) = ENothing ext (d2 t) -zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t)) -zero (STArr n t) = ENothing ext (STArr n (d2 t)) -zero (STScal t) = case t of - STI32 -> ENil ext - STI64 -> ENil ext +zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t +zero SMTNil _ = ENil ext +zero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) +zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) +zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e +zero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 - STBool -> ENil ext -zero STAccum{} = error "Accumulators not allowed in input program" -plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) -plus STNil _ _ = ENil ext -plus (STPair t1 t2) a b = - let t = STPair (d2 t1) (d2 t2) - in plusSparse t a b $ +plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t +plus SMTNil _ _ = ENil ext +plus (SMTPair t1 t2) a b = + let t = STPair (fromSMTy t1) (fromSMTy t2) + in ELet ext a $ + ELet ext (weakenExpr WSink b) $ EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) (EFst ext (EVar ext t IZ))) (plus t2 (ESnd ext (EVar ext t (IS IZ))) (ESnd ext (EVar ext t IZ))) -plus (STEither t1 t2) a b = - let t = STEither (d2 t1) (d2 t2) - in plusSparse t a b $ - ECase ext (EVar ext t (IS IZ)) - (ECase ext (EVar ext t (IS IZ)) - (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) +plus (SMTLEither t1 t2) a b = + let t = STLEither (fromSMTy t1) (fromSMTy t2) + in ELet ext a $ + ELet ext (weakenExpr WSink b) $ + ELCase ext (EVar ext t (IS IZ)) + (EVar ext t IZ) + (ELCase ext (EVar ext t (IS IZ)) + (EVar ext t (IS (IS IZ))) + (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ))) (EError ext t "plus l+r")) - (ECase ext (EVar ext t (IS IZ)) + (ELCase ext (EVar ext t (IS IZ)) + (EVar ext t (IS (IS IZ))) (EError ext t "plus r+l") - (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) -plus (STMaybe t) a b = - plusSparse (d2 t) a b $ - plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ) -plus (STArr n t) a b = - plusSparse (STArr n (d2 t)) a b $ - eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ) - (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ))) - (EVar ext (STArr n (d2 t)) (IS IZ)) - (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)) - (EVar ext (STArr n (d2 t)) (IS IZ)) - (EVar ext (STArr n (d2 t)) IZ))) -plus (STScal t) a b = case t of - STI32 -> ENil ext - STI64 -> ENil ext - STF32 -> EOp ext (OAdd STF32) (EPair ext a b) - STF64 -> EOp ext (OAdd STF64) (EPair ext a b) - STBool -> ENil ext -plus STAccum{} _ _ = error "Accumulators not allowed in input program" - -plusSparse :: STy a - -> Ex env (TMaybe a) -> Ex env (TMaybe a) - -> Ex (a : a : env) a - -> Ex env (TMaybe a) -plusSparse t a b adder = + (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ)))) +plus (SMTMaybe t) a b = ELet ext b $ EMaybe ext - (EVar ext (STMaybe t) IZ) + (EVar ext (STMaybe (fromSMTy t)) IZ) (EJust ext (EMaybe ext - (EVar ext t IZ) - (weakenExpr (WCopy (WCopy WSink)) adder) - (EVar ext (STMaybe t) (IS IZ)))) + (EVar ext (fromSMTy t) IZ) + (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) + (EVar ext (STMaybe (fromSMTy t)) (IS IZ)))) (weakenExpr WSink a) +plus (SMTArr _ t) a b = + ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) + a b +plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) -onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) +onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t onehot typ topprj idx arg = case (typ, topprj) of - (_, SAPHere) -> arg + (_, SAPHere) -> + ELet ext arg $ + EVar ext (fromSMTy typ) IZ - (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2)) - (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg)) + (SMTPair t1 t2, SAPFst prj) -> + ELet ext idx $ + let tidx = typeOf idx in + ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ + let toh = fromSMTy t1 in + EPair ext (EVar ext toh IZ) + (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) + + (SMTPair t1 t2, SAPSnd prj) -> + ELet ext idx $ + let tidx = typeOf idx in + ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ + let toh = fromSMTy t2 in + EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) + (EVar ext toh IZ) - (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg)) - (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg)) + (SMTLEither t1 t2, SAPLeft prj) -> + ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) + (SMTLEither t1 t2, SAPRight prj) -> + ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) - (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg) + (SMTMaybe t1, SAPJust prj) -> + EJust ext (onehot t1 prj idx arg) - (STArr n t1, SAPArrIdx prj _) -> + (SMTArr n t1, SAPArrIdx prj) -> let tidx = tTup (sreplicate n tIx) in ELet ext idx $ - EJust ext $ - EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ - eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) - (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) - (zero t1) + EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ + eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) + (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) + (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ + zero t1 (EVar ext (tZeroInfo t1) IZ)) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index f34bfbc..20575b3 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -30,6 +30,7 @@ data ValId t where VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a) VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value + VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b) VIArr :: Int -> Vec n Int -> ValId (TArr n t) VIScal :: Int -> ValId (TScal t) VIAccum :: Int -> ValId (TAccum t) @@ -45,6 +46,13 @@ instance PrettyX ValId where VIMaybe Nothing -> "N" VIMaybe (Just a) -> 'J' : prettyX a VIMaybe' a -> 'M' : prettyX a + VILEither (VIMaybe Nothing) -> "lN" + VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")" + VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")" + VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")" + VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")" + VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")" + VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))" VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]" VIScal i -> show i VIAccum i -> 'C' : show i @@ -147,6 +155,42 @@ idana env expr = case expr of res <- unify v1 v2 pure (res, EMaybe res e1' e2' e3') + ELNil _ t1 t2 -> do + let v = VILEither (VIMaybe Nothing) + pure (v, ELNil v t1 t2) + + ELInl _ t2 e1 -> do + (v1, e1') <- idana env e1 + let v = VILEither (VIMaybe (Just (VIEither (Left v1)))) + pure (v, ELInl v t2 e1') + + ELInr _ t1 e2 -> do + (v2, e2') <- idana env e2 + let v = VILEither (VIMaybe (Just (VIEither (Right v2)))) + pure (v, ELInr v t1 e2') + + ELCase _ e1 e2 e3 e4 -> do + let STLEither t1 t2 = typeOf e1 + (v1L, e1') <- idana env e1 + let VILEither v1 = v1L + let go mv1'l mv1'r f = do + v1'l <- maybe (genIds t1) pure mv1'l + v1'r <- maybe (genIds t2) pure mv1'r + (v2, e2') <- idana env e2 + (v3, e3') <- idana (v1'l `SCons` env) e3 + (v4, e4') <- idana (v1'r `SCons` env) e4 + res <- f v2 v3 v4 + pure (res, ELCase res e1' e2' e3' e4') + case v1 of + VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2) + VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3) + VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4) + VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4) + VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3) + VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4) + VIMaybe' (VIEither' v1'l v1'r) -> + go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4) + EConstArr _ dim t arr -> do x1 <- VIArr <$> genId <*> vecReplicateA dim genId pure (x1, EConstArr x1 dim t arr) @@ -265,20 +309,23 @@ idana env expr = case expr of (_, e3') <- idana env e3 pure (VINil, EAccum VINil t prj e1' e2' e3') - EZero _ t -> do - res <- genIds (d2 t) - pure (res, EZero res t) + EZero _ t e1 -> do + -- Approximate the result of EZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EZero res t e1') EPlus _ t e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- genIds (d2 t) + res <- genIds (fromSMTy t) pure (res, EPlus res t e1' e2') EOneHot _ t i e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- genIds (d2 t) + res <- genIds (fromSMTy t) pure (res, EOneHot res t i e1' e2') EError _ t s -> do @@ -307,6 +354,7 @@ unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VILEither a) (VILEither b) = VILEither <$> unify a b unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j @@ -323,6 +371,7 @@ genIds (STMaybe t) = VIMaybe' <$> genIds t genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId genIds STScal{} = VIScal <$> genId genIds STAccum{} = VIAccum <$> genId +genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b) shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int) shidsToVec SZ _ = pure VNil diff --git a/src/CHAD.hs b/src/CHAD.hs index 1126fde..ac308ac 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -292,7 +292,7 @@ conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) data Idx2 env sto t - = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum t)) + = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) | Idx2Me (Idx (Select env sto "merge") t) | Idx2Di (Idx (Select env sto "discr") t) @@ -319,7 +319,7 @@ conv2Idx DTop i = case i of {} zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext -zeroTup (SCons t env) = EPair ext (zeroTup env) (EZero ext t) +zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t) ------------------------------------ SUBENVS ----------------------------------- @@ -359,7 +359,7 @@ subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = ELet ext (weakenExpr WSink e2) $ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) (EFst ext (EVar ext (typeOf e2) IZ))) - (EPlus ext t + (EPlus ext (d2M t) (ESnd ext (EVar ext (typeOf e1) (IS IZ))) (ESnd ext (EVar ext (typeOf e2) IZ))) @@ -369,7 +369,7 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e = ELet ext e $ let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (EZero ext t) +expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t) assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl @@ -425,11 +425,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of (SEYes accrevsub) (VarMap.sink1 accumMap) (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr))) + autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) (#pro :++: #d :++: #shb :++: #acc :++: #tl) .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl))) + .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) @@ -453,7 +453,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of (SENo accrevsub) (let accumMap' = VarMap.sink1 accumMap in case fromArrayValId vid of - Just i -> VarMap.insert i (STAccum t) IZ accumMap' + Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap' Nothing -> accumMap') (\(shbinds :: SList _ shbinds) -> let shbindsC = slistMap (\_ -> Const ()) shbinds @@ -466,7 +466,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of -- goal: | ARE EQUAL || -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) WCopy (wf shbinds) - .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) -- Discrete values are left as-is, nothing to do @@ -493,6 +493,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of STF64 -> False STBool -> True STAccum{} -> False + STLEither a b -> isDiscrete a && isDiscrete b ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- @@ -596,7 +597,7 @@ drev des accumMap = \case SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) - (EAccum ext t SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum t) (IS accI))) + (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) Idx2Me tupI -> Ret BTop @@ -666,7 +667,7 @@ drev des accumMap = \case subtape (EFst ext e1) sub - (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (EZero ext t2))) $ + (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (ezeroD2 t2))) $ weakenExpr (WCopy WSink) e2) ESnd _ e @@ -676,7 +677,7 @@ drev des accumMap = \case subtape (ESnd ext e1) sub - (ELet ext (EJust ext (EPair ext (EZero ext t1) (EVar ext (d2 t2) IZ))) $ + (ELet ext (EJust ext (EPair ext (ezeroD2 t1) (EVar ext (d2 t2) IZ))) $ weakenExpr (WCopy WSink) e2) ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) @@ -687,12 +688,11 @@ drev des accumMap = \case subtape (EInl ext (d1 t2) e1) sub - (EMaybe ext + (ELCase ext + (EVar ext (STLEither (d2 (typeOf e)) (d2 t2)) IZ) (zeroTup (subList (select SMerge des) sub)) - (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) - (weakenExpr (WCopy (wSinks' @[_,_])) e2) + (weakenExpr (WCopy WSink) e2) (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) - (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ)) EInr _ t1 e | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> @@ -700,12 +700,11 @@ drev des accumMap = \case subtape (EInr ext (d1 t1) e1) sub - (EMaybe ext + (ELCase ext + (EVar ext (STLEither (d2 t1) (d2 (typeOf e))) IZ) (zeroTup (subList (select SMerge des) sub)) - (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") - (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ)) + (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") + (weakenExpr (WCopy WSink) e2)) ECase _ e (a :: Expr _ _ t) b | STEither t1 t2 <- typeOf e @@ -727,7 +726,7 @@ drev des accumMap = \case -> subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ -> subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E -> - let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in + let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STLEither (d2 t1) (d2 t2)) in Ret (e0 `BPush` (tPrimal, ECase ext e1 @@ -755,7 +754,7 @@ drev des accumMap = \case EPair ext (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)) - (EInl ext (d2 t2) + (ELInl ext (d2 t2) (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)))) (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ in letBinds rebinds $ @@ -774,10 +773,10 @@ drev des accumMap = \case EPair ext (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ)) - (EInr ext (d2 t1) + (ELInr ext (d2 t1) (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $ ELet ext - (ELet ext (EJust ext (ESnd ext (EVar ext tCaseRet IZ))) $ + (ELet ext (ESnd ext (EVar ext tCaseRet IZ)) $ weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ plus_AB_E (EFst ext (EVar ext tCaseRet (IS IZ))) @@ -934,8 +933,8 @@ drev des accumMap = \case (EMaybe ext (zeroTup (subList (select SMerge des) sub)) (ELet ext (EJust ext (EFold1Inner ext Commut - (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (EZero ext eltty) + (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (ezeroD2 eltty) (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ weakenExpr (WCopy (WSink .> WSink)) e2) (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) @@ -975,6 +974,7 @@ drev des accumMap = \case <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil , STArr n eltty <- typeOf e , Refl <- indexTupD1Id n + , Refl <- lemZeroInfoD2 eltty , let tIxN = tTup (sreplicate n tIx) -> Ret (binds `BPush` (STArr n (d1 eltty), e1) `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) @@ -983,10 +983,11 @@ drev des accumMap = \case (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EOneHot ext (STArr n eltty) (SAPArrIdx SAPHere n) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) (EVar ext tIxN (IS (IS IZ)))) - (ENil ext)) - (EVar ext (d2 eltty) IZ)) $ + (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere)) + (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) + (ENil ext)) + (EVar ext (d2 eltty) IZ)) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) EShape _ e @@ -1026,6 +1027,10 @@ drev des accumMap = \case ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" EMaybe{} -> err_unsupported "EMaybe" + ELNil{} -> err_unsupported "ELNil" + ELInl{} -> err_unsupported "ELInl" + ELInr{} -> err_unsupported "ELInr" + ELCase{} -> err_unsupported "ELCase" EWith{} -> err_accum EAccum{} -> err_accum @@ -1059,7 +1064,7 @@ drev des accumMap = \case (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (EZero ext t))) $ + (ezeroD2 t))) $ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) (EVar ext (d2 at') IZ)) @@ -1091,36 +1096,36 @@ drevScoped des accumMap argty argsto argids expr = case argsto of | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> case sub of SEYes sub' -> RetScoped e0 subtape e1 sub' e2 - SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty)) + SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty)) SAccum | Just (VIArr i _) <- argids , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap - , Just Refl <- testEquality foundTy (STAccum argty) + , Just Refl <- testEquality foundTy (STAccum (d2M argty)) , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> RetScoped e0 subtape e1 sub $ let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in - ELet ext (EVar ext (STAccum argty) (WSink .> wtapebinds @> idx)) $ + ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum a)) + &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) -- Our contribution to the binding's cotangent _here_ is -- zero, because we're contributing to an earlier binding -- of the same value instead. - (EPair ext e2 (EZero ext argty)) + (EPair ext e2 (ezeroD2 argty)) | let accumMap' = case argids of - Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap) + Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) _ -> VarMap.sink1 accumMap , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr -> RetScoped e0 subtape e1 sub $ - EWith ext argty (EZero ext argty) $ + EWith ext (d2M argty) (ezeroD2 argty) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum a)) + &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs index b61b5ff..d8a71b5 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Accum.hs @@ -10,9 +10,9 @@ import Data makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators SNil e = e -makeAccumulators (t `SCons` envpro) e = +makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t = makeAccumulators envpro $ - EWith ext t (EZero ext t) e + EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) uninvertTup SNil _ e = EPair ext e (ENil ext) diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 2c01178..9e7e7f5 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -53,6 +53,7 @@ d1Identity = \case STArr _ t | Refl <- d1Identity t -> Refl STScal _ -> Refl STAccum{} -> error "Accumulators not allowed in input program" + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl d1eIdentity :: SList STy env -> D1E env :~: env d1eIdentity SNil = Refl @@ -72,7 +73,7 @@ reassembleD2E (des `DPush` (_, _, SMerge)) e = EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t) +reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (ezeroD2 t) chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) chad config env (term :: Ex env t) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 7f49cef..74e7dbd 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -14,14 +14,16 @@ type family D1 t where D1 (TMaybe a) = TMaybe (D1 a) D1 (TArr n t) = TArr n (D1 t) D1 (TScal t) = TScal t + D1 (TLEither a b) = TLEither (D1 a) (D1 b) type family D2 t where D2 TNil = TNil D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) - D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b)) + D2 (TEither a b) = TLEither (D2 a) (D2 b) D2 (TMaybe t) = TMaybe (D2 t) D2 (TArr n t) = TMaybe (TArr n (D2 t)) D2 (TScal t) = D2s t + D2 (TLEither a b) = TLEither (D2 a) (D2 b) type family D2s t where D2s TI32 = TNil @@ -40,7 +42,7 @@ type family D2E env where type family D2AcE env where D2AcE '[] = '[] - D2AcE (t : env) = TAccum t : D2AcE env + D2AcE (t : env) = TAccum (D2 t) : D2AcE env d1 :: STy t -> STy (D1 t) d1 STNil = STNil @@ -50,32 +52,40 @@ d1 (STMaybe t) = STMaybe (d1 t) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STAccum{} = error "Accumulators not allowed in input program" +d1 (STLEither a b) = STLEither (d1 a) (d1 b) d1e :: SList STy env -> SList STy (D1E env) d1e SNil = SNil d1e (t `SCons` env) = d1 t `SCons` d1e env +d2M :: STy t -> SMTy (D2 t) +d2M STNil = SMTNil +d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b)) +d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STMaybe t) = SMTMaybe (d2M t) +d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t)) +d2M (STScal t) = case t of + STI32 -> SMTNil + STI64 -> SMTNil + STF32 -> SMTScal STF32 + STF64 -> SMTScal STF64 + STBool -> SMTNil +d2M STAccum{} = error "Accumulators not allowed in input program" +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) + d2 :: STy t -> STy (D2 t) -d2 STNil = STNil -d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b)) -d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b)) -d2 (STMaybe t) = STMaybe (d2 t) -d2 (STArr n t) = STMaybe (STArr n (d2 t)) -d2 (STScal t) = case t of - STI32 -> STNil - STI64 -> STNil - STF32 -> STScal STF32 - STF64 -> STScal STF64 - STBool -> STNil -d2 STAccum{} = error "Accumulators not allowed in input program" +d2 = fromSMTy . d2M + +d2eM :: SList STy env -> SList SMTy (D2E env) +d2eM SNil = SNil +d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts d2e :: SList STy env -> SList STy (D2E env) -d2e SNil = SNil -d2e (t `SCons` ts) = d2 t `SCons` d2e ts +d2e = slistMap fromSMTy . d2eM d2ace :: SList STy env -> SList STy (D2AcE env) d2ace SNil = SNil -d2ace (t `SCons` ts) = STAccum t `SCons` d2ace ts +d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts data CHADConfig = CHADConfig diff --git a/src/Compile.hs b/src/Compile.hs index e2d004a..503c342 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -45,7 +45,6 @@ import qualified Prelude import Array import AST import AST.Pretty (ppSTy, ppExpr) -import qualified CHAD.Types as CHAD import Compile.Exec import Data import Interpreter.Rep @@ -230,11 +229,15 @@ genStructName = \t -> "ty_" ++ gen t where STF32 -> "f" STF64 -> "d" STBool -> "b" - gen (STAccum t) = 'C' : gen t + gen (STAccum t) = 'C' : gen (fromSMTy t) + gen (STLEither a b) = 'L' : gen a ++ gen b -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the -- types in the C translation. +-- +-- For accumulation it is important that for struct representations of monoid +-- types, the all-zero-bytes value corresponds to the zero value of that type. genStruct :: String -> STy t -> [StructDecl] genStruct name topty = case topty of STNil -> @@ -247,13 +250,17 @@ genStruct name topty = case topty of [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] STArr n t -> -- The buffer is trailed by a VLA for the actual array data. + -- TODO: put shape in the main struct, not the buffer; it's constant, after all + -- TODO: no buffer if n = 0 [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") "" ,StructDecl name (name ++ "_buf *buf;") com] STScal _ -> [] STAccum t -> - [StructDecl (name ++ "_buf") (repSTy (CHAD.d2 t) ++ " ac;") "" + [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" ,StructDecl name (name ++ "_buf *buf;") com] + STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r + [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] where com = ppSTy 0 topty @@ -278,7 +285,8 @@ genStructs ty = do STMaybe t -> genStructs t STArr _ t -> genStructs t STScal _ -> pure () - STAccum t -> genStructs (CHAD.d2 t) + STAccum t -> genStructs (fromSMTy t) + STLEither a b -> genStructs a >> genStructs b tell (BList (genStruct name ty)) @@ -450,7 +458,7 @@ serialise topty topval ptr off k = serialise a x ptr off $ serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k (STEither a _, Left x) -> do - pokeByteOff ptr off (0 :: Word8) -- alignment of (a + b) is alignment of (union {a b}) + pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b) serialise a x ptr (off + alignmentSTy topty) k (STEither _ b, Right y) -> do pokeByteOff ptr off (1 :: Word8) @@ -485,6 +493,15 @@ serialise topty topval ptr off k = STF64 -> pokeByteOff ptr off (x :: Double) >> k STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k (STAccum{}, _) -> error "Cannot serialise accumulators" + (STLEither _ _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STLEither a _, Just (Left x)) -> do + pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STLEither _ b, Just (Right y)) -> do + pokeByteOff ptr off (2 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k -- | Assumes that this is called at the correct alignment. deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) @@ -498,7 +515,7 @@ deserialise topty ptr off = return (x, y) STEither a b -> do tag <- peekByteOff @Word8 ptr off - if tag == 0 -- alignment of (a + b) is alignment of (union {a b}) + if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) then Left <$> deserialise a ptr (off + alignmentSTy topty) else Right <$> deserialise b ptr (off + alignmentSTy topty) STMaybe t -> do @@ -524,6 +541,13 @@ deserialise topty ptr off = STF64 -> peekByteOff @Double ptr off STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off STAccum{} -> error "Cannot serialise accumulators" + STLEither a b -> do + tag <- peekByteOff @Word8 ptr off + case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) + 0 -> return Nothing + 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) + 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) + _ -> error "Invalid tag value" align :: Int -> Int -> Int align a off = (off + a - 1) `div` a * a @@ -555,7 +579,11 @@ metricsSTy (STScal sty) = case sty of STF32 -> (4, 4) STF64 -> (8, 8) STBool -> (1, 1) -- compiled to uint8_t -metricsSTy (STAccum t) = metricsSTy t +metricsSTy (STAccum t) = metricsSTy (fromSMTy t) +metricsSTy (STLEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () pokeShape ptr off = go . fromSNat @@ -685,6 +713,39 @@ compile' env = \case <> pure (SAsg retvar e3)))) return (CELit retvar) + ELNil _ t1 t2 -> do + name <- emitStruct (STLEither t1 t2) + return $ CEStruct name [("tag", CELit "0")] + + ELInl _ t e -> do + name <- emitStruct (STLEither (typeOf e) t) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("l", e1)] + + ELInr _ t e -> do + name <- emitStruct (STLEither t (typeOf e)) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "2"), ("r", e1)] + + ELCase _ e a b c -> do + let STLEither t1 t2 = typeOf e + e1 <- compile' env e + var <- genName + (e2, stmts2) <- scope $ compile' env a + (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b + (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c + ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l") + ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r") + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (stmts2 <> pure (SAsg retvar e2)) + (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1")) + (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3)) + (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4)))))) + return (CELit retvar) + EConstArr _ n t (Array sh vec) -> do strname <- emitStruct (STArr n (STScal t)) tldname <- genName' "carraybuf" @@ -734,8 +795,7 @@ compile' env = \case -- unexpected. But it's exactly what we want, so we do it anyway. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname) - resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) - [CELit (arrname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname) lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname @@ -781,8 +841,7 @@ compile' env = \case -- This n is one less than the shape of the thing we're querying, like EFold1Inner. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) - [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname) lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname @@ -833,8 +892,7 @@ compile' env = \case resname <- allocArray "repl1i" Malloc "rep" (SS n) t (Just (CEBinop (CELit shszname) "*" (CELit lenname))) - ([CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] - ++ [CELit lenname]) + (compileArrShapeComponents n argname ++ [CELit lenname]) ivar <- genName' "i" jvar <- genName' "j" @@ -926,20 +984,20 @@ compile' env = \case zeroRefcountCheck (typeOf e1) "with" name1 emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")" - mcopy <- copyForWriting (CHAD.d2 t) name1 + mcopy <- copyForWriting t name1 accname <- genName' "accum" emit $ SVarDecl False actyname accname - (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (CHAD.d2 t)))])]) + (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])]) emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy) emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")." e2' <- compile' (Const accname `SCons` env) e2 resname <- genName' "acret" - emit $ SVarDecl True (repSTy (CHAD.d2 t)) resname (CELit (accname++".buf->ac")) + emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac")) emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);" - rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t)) + rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] EAccum _ t prj eidx eval eacc -> do @@ -947,156 +1005,180 @@ compile' env = \case nameval <- compileAssign "acval" env eval -- Generate the variable manually because this one has to be non-const. - eacc' <- compile' env eacc - nameacc <- genName' "acac" - emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc' - - let -- Expects a variable reference to a value of type @D2 a@. - setZero :: STy a -> String -> CompM () - setZero STNil _ = return () - setZero STPair{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Pair (D2 a) (D2 b)) - setZero STEither{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Either (D2 a) (D2 b)) - setZero STMaybe{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (D2 a) - setZero STArr{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Arr n (D2 a)) - setZero (STScal sty) v = case sty of - STI32 -> return () -- Nil - STI64 -> return () -- Nil + -- TODO: old code: + -- eacc' <- compile' env eacc + -- nameacc <- genName' "acac" + -- emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc' + nameacc <- compileAssign "acac" env eacc + + let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a + -- full zero array with the given zero info (for the type SMTArr n t1). + initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM () + initZeroArray n t1 v vzi = do + shszname <- genName' "inacshsz" + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi) + newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi) + emit $ SAsg v (CELit newarrName) + forM_ (initZeroFromMemset t1) $ \f1 -> do + ivar <- genName' "i" + ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]") + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts + + -- If something needs to be done to properly initialise this type to + -- zero after memory has already been initialised to all-zero bytes, + -- returns an action that does so. + -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) + initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ()) + initZeroFromMemset SMTNil = Nothing + initZeroFromMemset (SMTPair t1 t2) = + case (initZeroFromMemset t1, initZeroFromMemset t2) of + (Nothing, Nothing) -> Nothing + (mf1, mf2) -> Just $ \v vzi -> do + forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a") + forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b") + initZeroFromMemset SMTLEither{} = Nothing + initZeroFromMemset SMTMaybe{} = Nothing + initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi + initZeroFromMemset SMTScal{} = Nothing + + let -- initZero (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) + initZero :: SMTy a -> String -> String -> CompM () + initZero SMTNil _ _ = return () + initZero (SMTPair t1 t2) v vzi = do + initZero t1 (v++".a") (vzi++".a") + initZero t2 (v++".b") (vzi++".b") + initZero SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") + initZero SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") + initZero (SMTArr n t1) v vzi = initZeroArray n t1 v vzi + initZero (SMTScal sty) v _ = case sty of + STI32 -> emit $ SAsg v (CELit "0") + STI64 -> emit $ SAsg v (CELit "0l") STF32 -> emit $ SAsg v (CELit "0.0f") STF64 -> emit $ SAsg v (CELit "0.0") - STBool -> return () -- Nil - setZero STAccum{} _ = error "Compile: setZero: nested accumulators unsupported" - initD2Pair :: STy a -> STy b -> String -> CompM () - initD2Pair a b v = do -- Maybe (Pair (D2 a) (D2 b)) - ((), stmts1) <- scope $ setZero a (v++".j.a") - ((), stmts2) <- scope $ setZero b (v++".j.b") - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmts1 <> stmts2) - mempty + let -- | Dereference an accumulation value. Sparse components encountered + -- along the way are initialised before proceeding downwards. At the + -- point where we have the projected accumulator position available, + -- the handler will be invoked with a variable name pointing to the + -- projected position. + -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (handler) + accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> (String -> CompM ()) -> CompM () + accumRef _ SAPHere v _ k = k v - initD2Either :: STy a -> STy b -> String -> Either () () -> CompM () - initD2Either a b v side = do -- Maybe (Either (D2 a) (D2 b)) - ((), stmts) <- case side of - Left () -> scope $ setZero a (v++".j.l") - Right () -> scope $ setZero b (v++".j.r") - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmts) - mempty + accumRef (SMTPair ta _) (SAPFst prj') v i k = accumRef ta prj' (v++".a") (i++".a") k + accumRef (SMTPair _ tb) (SAPSnd prj') v i k = accumRef tb prj' (v++".b") (i++".b") k - initD2Maybe :: STy a -> String -> CompM () - initD2Maybe a v = do -- Maybe (D2 a) - ((), stmts) <- scope $ setZero a (v++".j") + accumRef (SMTLEither ta _) (SAPLeft prj') v i k = do + ((), stmtsInit1) <- scope $ initZero ta (v++".l") i emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmts) - mempty + (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty + accumRef ta prj' (v++".l") i k + accumRef (SMTLEither _ tb) (SAPRight prj') v i k = do + ((), stmtsInit2) <- scope $ initZero tb (v++".r") i + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit "2")) <> stmtsInit2) mempty + accumRef tb prj' (v++".r") i k - -- mind: this has to traverse the D2 of these things, and it also has to - -- initialise data structures that are still sparse in the accumulator. - let accumRef :: STy a -> SAcPrj p a b -> String -> String -> CompM String - accumRef _ SAPHere v _ = pure v - accumRef (STPair ta tb) (SAPFst prj') v i = do - initD2Pair ta tb v - accumRef ta prj' (v++".j.a") i - accumRef (STPair ta tb) (SAPSnd prj') v i = do - initD2Pair ta tb v - accumRef tb prj' (v++".j.b") i - accumRef (STEither ta tb) (SAPLeft prj') v i = do - initD2Either ta tb v (Left ()) - accumRef ta prj' (v++".j.l") i - accumRef (STEither ta tb) (SAPRight prj') v i = do - initD2Either ta tb v (Right ()) - accumRef tb prj' (v++".j.r") i - accumRef (STMaybe tj) (SAPJust prj') v i = do - initD2Maybe tj v - accumRef tj prj' (v++".j") i - accumRef (STArr n t') (SAPArrIdx prj' _) v i = do - (newarrName, newarrStmts) <- scope $ allocArray "accumRef" Calloc "prjarr" n t' Nothing (indexTupleComponents n (i++".a.b")) + accumRef (SMTMaybe tj) (SAPJust prj') v i k = do + ((), stmtsInit1) <- scope $ initZero tj (v++".j") i emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) - <> newarrStmts - <> pure (SAsg (v++".j") (CELit newarrName))) - mempty + (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty + accumRef tj prj' (v++".j") i k + accumRef (SMTArr n t') (SAPArrIdx prj') v i k = do when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" forM_ (zip3 [0::Int ..] (indexTupleComponents n (i++".a.a")) - (indexTupleComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do + (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do let a .||. b = CEBinop a "||" b emit $ SIf (CEBinop ixcomp "<" (CELit "0") .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]"))) + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) .||. - CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]")))) + CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ - v ++ ".j.buf" ++ - concat [", " ++ v ++ ".j.buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++ + v ++ ".buf" ++ + concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ "); " ++ "return false;") mempty - accumRef t' prj' (v++".j.buf->xs[" ++ printCExpr 0 (toLinearIdx n (v++".j") (i++".a.a")) "]") (i++".b") - - -- mind: this has to add the D2 of these things, and it also has to - -- initialise data structures that are still sparse in the accumulator. - let add :: STy a -> String -> String -> CompM () - add STNil _ _ = return () - add (STPair t1 t2) d s = do - ((), stmts1) <- scope $ add t1 (d++".j.a") (s++".j.a") - ((), stmts2) <- scope $ add t2 (d++".j.b") (s++".j.b") - emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (stmts1 <> stmts2) - mempty)) - add (STEither t1 t2) d s = do - ((), stmts1) <- scope $ add t1 (d++".j.l") (s++".j.l") - ((), stmts2) <- scope $ add t2 (d++".j.r") (s++".j.r") + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") k + + let -- Add a value (s) into an existing accumulation value (d). If a sparse + -- component of d is encountered, s is simply written there. + add :: SMTy a -> String -> String -> CompM () + add SMTNil _ _ = return () + add (SMTPair t1 t2) d s = do + add t1 (d++".a") (s++".a") + add t2 (d++".b") (s++".b") + add (SMTLEither t1 t2) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s + ((), stmts1) <- scope $ add t1 (d++".l") (s++".l") + ((), stmts2) <- scope $ add t2 (d++".r") (s++".r") emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (pure (SAsg (d++".j.tag") (CELit (s++".j.tag"))) - <> pure (SIf (CEBinop (CELit (s++".j.tag")) "==" (CELit "0")) - stmts1 stmts2)) - mempty)) - add (STMaybe t1) d s = do + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + ((if emitChecks + then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0")) + "&&" + (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag")))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++ + "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++ + "return false;") + mempty) + else mempty) + -- note: s may have tag 0 + <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) + stmts1 + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2")) + stmts2 mempty)))) + add (SMTMaybe t1) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s ((), stmts1) <- scope $ add t1 (d++".j") (s++".j") emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (pure (SAsg (d++".tag") (CELit "1")) <> stmts1) - mempty)) - add (STArr n t1) d s = do + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty)) + add (SMTArr n t1) d s = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ [0 .. fromSNat n - 1] $ \j -> do + emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]")) + "!=" + (CELit (d ++ ".buf->sh[" ++ show j ++ "]"))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ + "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ + d ++ ".buf" ++ + concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + ", " ++ s ++ ".buf" ++ + concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + "); " ++ + "return false;") + mempty + shsizename <- genName' "acshsz" + emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j")) ivar <- genName' "i" - ((), stmts1) <- scope $ add t1 (d++".j.buf->xs["++ivar++"]") (s++".j.buf->xs["++ivar++"]") - ((), stmtsDecr) <- scope $ incrementVarAlways "accumarr" Decrement (STArr n (CHAD.d2 t1)) (s++".j") - emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (pure (SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j"))) - -- TODO: emit check here for the source being either equal in shape to the destination - <> pure (SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) - stmts1) - <> stmtsDecr))) - mempty - add (STScal sty) d s = case sty of - STI32 -> return () - STI64 -> return () - STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STBool -> return () - add (STAccum _) _ _ = error "Compile: nested accumulators unsupported" + ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) + stmts1 + add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" - dest <- accumRef t prj (nameacc++".buf->ac") nameidx - add (acPrjTy prj t) dest nameval + accumRef t prj (nameacc++".buf->ac") nameidx $ \dest -> + add (acPrjTy prj t) dest nameval emit $ SVerbatim $ "// compile EAccum end" + incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval + return $ CEStruct (repSTy STNil) [] EError _ t s -> do @@ -1111,9 +1193,9 @@ compile' env = \case name <- emitStruct t return $ CEStruct name [] - EZero{} -> error "Compile: monoid operations should have been eliminated" - EPlus{} -> error "Compile: monoid operations should have been eliminated" - EOneHot{} -> error "Compile: monoid operations should have been eliminated" + EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EIdx1{} -> error "Compile: not implemented: EIdx1" @@ -1144,6 +1226,7 @@ data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array | ATNoop -- ^ don't do anything here | ATProj String ArrayTree -- ^ descend one field deeper | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second + | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2 | ATBoth ArrayTree ArrayTree -- ^ do both these paths smartATProj :: String -> ArrayTree -> ArrayTree @@ -1154,6 +1237,10 @@ smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree smartATCondTag ATNoop ATNoop = ATNoop smartATCondTag t t' = ATCondTag t t' +smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree +smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop +smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3 + smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree smartATBoth ATNoop t = t smartATBoth t ATNoop = t @@ -1169,6 +1256,9 @@ makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTre makeArrayTree (STArr n t) = ATArray (Some n) (Some t) makeArrayTree (STScal _) = ATNoop makeArrayTree (STAccum _) = ATNoop +makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop + (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM () incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = @@ -1204,6 +1294,15 @@ incrementVar' marker inc path (ATCondTag t1 t2) = do ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2 +incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do + ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 + ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 + ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3 + emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1")) + stmts2 + (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2")) + stmts3 + stmts1)) incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2 toLinearIdx :: SNat n -> String -> String -> CExpr @@ -1257,10 +1356,12 @@ compileShapeQuery (SS n) var = -- | Takes a variable name for the array, not the buffer. compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize SZ _ = CELit "1" -compileArrShapeSize n var = - foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]") - | i <- [0 .. fromSNat n - 1]] +compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var) + +-- | Takes a variable name for the array, not the buffer. +compileArrShapeComponents :: SNat n -> String -> [CExpr] +compileArrShapeComponents n var = + [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] indexTupleComponents :: SNat n -> String -> [CExpr] indexTupleComponents = \n var -> map CELit (toList (go n var)) @@ -1347,8 +1448,7 @@ compileExtremum nameBase opName operator env e = do -- unexpected. But it's exactly what we want, so we do it anyway. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) - [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname) lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname @@ -1375,47 +1475,47 @@ compileExtremum nameBase opName operator env e = do -- | If this returns Nothing, there was nothing to copy because making a simple -- value copy in C already makes it suitable to write to. -copyForWriting :: STy t -> String -> CompM (Maybe CExpr) +copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr) copyForWriting topty var = case topty of - STNil -> return Nothing + SMTNil -> return Nothing - STPair a b -> do + SMTPair a b -> do e1 <- copyForWriting a (var ++ ".a") e2 <- copyForWriting b (var ++ ".b") case (e1, e2) of (Nothing, Nothing) -> return Nothing - _ -> return $ Just $ CEStruct (repSTy topty) + _ -> return $ Just $ CEStruct toptyname [("a", fromMaybe (CELit (var++".a")) e1) ,("b", fromMaybe (CELit (var++".b")) e2)] - STEither a b -> do + SMTLEither a b -> do (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l") (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r") case (e1, e2) of (Nothing, Nothing) -> return Nothing _ -> do name <- genName - emit $ SVarDeclUninit (repSTy topty) name + emit $ SVarDeclUninit toptyname name emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) (stmts1 - <> pure (SAsg name (CEStruct (repSTy topty) + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) (stmts2 - <> pure (SAsg name (CEStruct (repSTy topty) + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)]))) return (Just (CELit name)) - STMaybe t -> do + SMTMaybe t -> do (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j") case e1 of Nothing -> return Nothing Just e1' -> do name <- genName - emit $ SVarDeclUninit (repSTy topty) name + emit $ SVarDeclUninit toptyname name emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) - (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")]))) + (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")]))) (stmts1 - <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')]))) + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')]))) return (Just (CELit name)) -- If there are no nested arrays, we know that a refcount of 1 means that the @@ -1423,10 +1523,10 @@ copyForWriting topty var = case topty of -- nesting we'd have to check the refcounts of all the nested arrays _too_; -- let's not do that. Furthermore, no sub-arrays means that the whole thing -- is flat, and we can just memcpy if necessary. - STArr n t | not (hasArrays t) -> do + SMTArr n t | not (hasArrays (fromSMTy t)) -> do name <- genName shszname <- genName' "shsz" - emit $ SVarDeclUninit (repSTy (STArr n t)) name + emit $ SVarDeclUninit toptyname name when debugShapes $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" @@ -1438,11 +1538,11 @@ copyForWriting topty var = case topty of emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) (pure (SAsg name (CELit var))) (let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes in BList [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) - ,SAsg name (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc_instr" [totalbytes])]) + ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ show shbytes ++ ");" ,SAsg (name ++ ".buf->refc") (CELit "1") @@ -1450,26 +1550,26 @@ copyForWriting topty var = case topty of printCExpr 0 databytes ");"]) return (Just (CELit name)) - STArr n t -> do + SMTArr n t -> do shszname <- genName' "shsz" emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes name <- genName - emit $ SVarDecl False (repSTy (STArr n t)) name - (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc_instr" [totalbytes])]) + emit $ SVarDecl False toptyname name + (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ show shbytes ++ ");" emit $ SAsg (name ++ ".buf->refc") (CELit "1") -- put the arrays in variables to cut short the not-quite-var chain dstvar <- genName' "cpydst" - emit $ SVarDecl True (repSTy t ++ " *") dstvar (CELit (name ++ ".buf->xs")) + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs")) srcvar <- genName' "cpysrc" - emit $ SVarDecl True (repSTy t ++ " *") srcvar (CELit (var ++ ".buf->xs")) + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs")) ivar <- genName' "i" @@ -1484,9 +1584,10 @@ copyForWriting topty var = case topty of return (Just (CELit name)) - STScal _ -> return Nothing + SMTScal _ -> return Nothing - STAccum _ -> error "Compile: Nested accumulators not supported" + where + toptyname = repSTy (fromSMTy topty) zeroRefcountCheck :: STy t -> String -> String -> CompM () zeroRefcountCheck toptyp opname topvar = @@ -1521,6 +1622,14 @@ zeroRefcountCheck toptyp opname topvar = return (BList [s1, s2, s3]) go STScal{} _ = empty go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" + go (STLEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ + SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) + s1 + (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) + s2 + mempty)) combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) combine (MaybeT a) (MaybeT b) = MaybeT $ do diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 2f94076..ebc70d7 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -143,6 +143,10 @@ dfwdDN = \case ENothing _ t -> ENothing ext (dn t) EJust _ e -> EJust ext (dfwdDN e) EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2) + ELInl _ t e -> ELInl ext (dn t) (dfwdDN e) + ELInr _ t e -> ELInr ext (dn t) (dfwdDN e) + ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c) EConstArr _ n t x -> scalTyCase t (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) (EConstArr ext n t x)) diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs index fba92d0..3c76cbe 100644 --- a/src/ForwardAD/DualNumbers/Types.hs +++ b/src/ForwardAD/DualNumbers/Types.hs @@ -15,6 +15,7 @@ type family DN t where DN (TMaybe t) = TMaybe (DN t) DN (TArr n t) = TArr n (DN t) DN (TScal t) = DNS t + DN (TLEither a b) = TLEither (DN a) (DN b) type family DNS t where DNS TF32 = TPair (TScal TF32) (TScal TF32) @@ -40,6 +41,7 @@ dn (STScal t) = case t of STI64 -> STScal STI64 STBool -> STScal STBool dn STAccum{} = error "Accum in source program" +dn (STLEither a b) = STLEither (dn a) (dn b) dne :: SList STy env -> SList STy (DNE env) dne SNil = SNil diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 58d79a5..f8e7e98 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -99,6 +99,15 @@ interpret'Rec env = \case EMaybe _ a b e -> let STMaybe t1 = typeOf e in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e + ELNil _ _ _ -> return Nothing + ELInl _ _ e -> Just . Left <$> interpret' env e + ELInr _ _ e -> Just . Right <$> interpret' env e + ELCase _ e a b c -> + let STLEither t1 t2 = typeOf e + in interpret' env e >>= \case + Nothing -> interpret' env a + Just (Left x) -> interpret' (V t1 x `SCons` env) b + Just (Right y) -> interpret' (V t2 y `SCons` env) c EConstArr _ _ _ v -> return v EBuild _ dim a b -> do sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a @@ -136,9 +145,9 @@ interpret'Rec env = \case EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) - EIdx _ a b - | STArr n _ <- typeOf a - -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) + EIdx _ a b -> + let STArr n _ = typeOf a + in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e EOp _ op e -> interpretOp op <$> interpret' env e ECustom _ t1 t2 _ pr _ _ e1 e2 -> do @@ -154,8 +163,8 @@ interpret'Rec env = \case val <- interpret' env e2 accum <- interpret' env e3 accumAddSparse t p accum idx val - EZero _ t -> do - return $ zeroD2 t + EZero _ t ezi -> do + return $ zeroD2 t ezi EPlus _ t a b -> do a' <- interpret' env a b' <- interpret' env b @@ -250,7 +259,7 @@ onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val)) onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) -onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = +onehotD2 (SAPArrIdx prj) (STArr n a) idx val = Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) @@ -299,7 +308,7 @@ newAcSparse typ prj idx val = case (typ, prj) of (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val - (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val + (STArr n t, SAPArrIdx prj') -> newIORef . Just =<< newAcArray n t prj' idx val (STAccum{}, _) -> error "Accumulators not allowed in source program" @@ -380,7 +389,7 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of (arrayMapM (newAcSparse t1 SAPHere ()) val') (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) - (STArr n t1, SAPArrIdx prj' _) -> + (STArr n t1, SAPArrIdx prj') -> let ((arrindex', arrsh'), idx') = idx arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = unTupRepIdx ShNil ShCons n arrsh' diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index be2a4cc..9056901 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -22,6 +22,7 @@ type family Rep t where Rep (TArr n t) = Array n (Rep t) Rep (TScal sty) = ScalRep sty Rep (TAccum t) = RepAc t + Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b)) -- Mutable, represents D2 of t. Has an O(1) zero. type family RepAc t where @@ -32,6 +33,7 @@ type family RepAc t where RepAc (TArr n t) = IORef (Maybe (Array n (RepAc t))) RepAc (TScal sty) = RepAcScal sty RepAc (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators") + RepAc (TLEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b))) type family RepAcScal t where RepAcScal TI32 = () @@ -57,8 +59,8 @@ vUnpair (Value (x, y)) = (Value x, Value y) showValue :: Int -> STy t -> Rep t -> ShowS showValue _ STNil () = showString "()" showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" -showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x -showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y +showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x +showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y showValue _ (STMaybe _) Nothing = showString "Nothing" showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x showValue d (STArr _ t) arr = showParen (d > 10) $ @@ -72,7 +74,10 @@ showValue _ (STScal sty) x = case sty of STI32 -> shows x STI64 -> shows x STBool -> shows x -showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSTy 0 t ++ ">" +showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">" +showValue _ (STLEither _ _) Nothing = showString "LNil" +showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x +showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y showEnv :: SList STy env -> SList Value env -> String showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" diff --git a/src/Language.hs b/src/Language.hs index 4ed4eaa..9fd5dd3 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -169,11 +169,11 @@ custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 -with :: forall t a env acname. KnownTy t => NExpr env (D2 t) -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a (D2 t)) -with a (n :-> b) = NEWith (knownTy @t) a n b +with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) +with a (n :-> b) = NEWith (knownMTy @t) a n b -accum :: KnownTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env (D2 a) -> NExpr env (TAccum t) -> NExpr env TNil -accum p a b c = NEAccum knownTy p a b c +accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil +accum p a b c = NEAccum knownMTy p a b c (.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 84544f8..8bcb5e5 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -72,8 +72,8 @@ data NExpr env t where -> NExpr env t -- accumulation effect on monoids - NEWith :: STy t -> NExpr env (D2 t) -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a (D2 t)) - NEAccum :: STy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env (D2 a) -> NExpr env (TAccum t) -> NExpr env TNil + NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil -- partiality NEError :: STy a -> String -> NExpr env a diff --git a/src/Simplify.hs b/src/Simplify.hs index ea3bb95..228f265 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -19,7 +19,6 @@ import Data.Type.Equality (testEquality) import AST import AST.Count -import CHAD.Types import Data @@ -169,35 +168,33 @@ simplify' = \case (Any True, ENil ext) (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc')) (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) - EPlus _ _ (EZero _ _) e -> acted $ simplify' e - EPlus _ _ e (EZero _ _) -> acted $ simplify' e + EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e + EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e EOneHot _ t p e1 e2 -> do e1' <- simplify' e1 e2' <- simplify' e2 simplifyOneHotTerm (OneHotTerm t p e1' e2') - (Any True, EZero ext t) + (Any True, EZero ext t (zeroInfoFromOneHot t p e1 e2)) (\e -> (Any True, e)) (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) -- type-specific equations for plus - EPlus _ STNil _ _ -> (Any True, ENil ext) + EPlus _ SMTNil _ _ -> (Any True, ENil ext) - EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) -> - acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)) - EPlus _ STPair{} ENothing{} e -> acted $ simplify' e - EPlus _ STPair{} e ENothing{} -> acted $ simplify' e + EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) -> + acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2) - EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) -> - acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2)) - EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) -> - acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2)) - EPlus _ STEither{} ENothing{} e -> acted $ simplify' e - EPlus _ STEither{} e ENothing{} -> acted $ simplify' e + EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) -> + acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2) + EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) -> + acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2) + EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e + EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e - EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) -> + EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) -> acted $ simplify' $ EJust ext (EPlus ext t e1 e2) - EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e - EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e + EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e + EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e -- fallback recursion EVar _ t i -> pure $ EVar ext t i @@ -212,6 +209,10 @@ simplify' = \case ENothing _ t -> pure $ ENothing ext t EJust _ e -> EJust ext <$> simplify' e EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e + ELNil _ t1 t2 -> pure $ ELNil ext t1 t2 + ELInl _ t e -> ELInl ext t <$> simplify' e + ELInr _ t e -> ELInr ext t <$> simplify' e + ELCase _ e a b c -> ELCase ext <$> simplify' e <*> simplify' a <*> simplify' b <*> simplify' c EConstArr _ n t v -> pure $ EConstArr ext n t v EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c @@ -233,7 +234,7 @@ simplify' = \case <*> (let ?accumInScope = False in simplify' c) <*> simplify' e1 <*> simplify' e2 EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) - EZero _ t -> pure $ EZero ext t + EZero _ t e -> EZero ext t <$> simplify' e EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b EError _ t s -> pure $ EError ext t s @@ -266,6 +267,10 @@ hasAdds = \case ENothing _ _ -> False EJust _ e -> hasAdds e EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e + ELNil _ _ _ -> False + ELInl _ _ e -> hasAdds e + ELInr _ _ e -> hasAdds e + ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c EConstArr _ _ _ _ -> False EBuild _ _ a b -> hasAdds a || hasAdds b EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c @@ -283,7 +288,7 @@ hasAdds = \case EOp _ _ e -> hasAdds e EWith _ _ a b -> hasAdds a || hasAdds b EAccum _ _ _ _ _ _ -> True - EZero _ _ -> False + EZero _ _ e -> hasAdds e EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b EError _ _ _ -> False @@ -300,17 +305,18 @@ checkAccumInScope = \case SNil -> False check (STArr _ t) = check t check (STScal _) = False check STAccum{} = True + check (STLEither s t) = check s || check t data OneHotTerm env p a b where - OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b + OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b deriving instance Show (OneHotTerm env p a b) simplifyOneHotTerm :: OneHotTerm env p a b -> (Any, r) -- ^ Zero case (onehot is actually zero) - -> (Ex env (D2 a) -> (Any, r)) -- ^ Trivial case (no zeros in onehot) + -> (Ex env a -> (Any, r)) -- ^ Trivial case (no zeros in onehot) -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r)) -> (Any, r) -simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero +simplifyOneHotTerm (OneHotTerm _ _ _ EZero{}) kzero _ _ = kzero simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k | Just Refl <- testEquality (acPrjTy prj1 t1) t2 @@ -318,57 +324,79 @@ simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k -simplifyOneHotTerm (OneHotTerm t SAPHere idx e) kzero ktriv k = case (t, e) of - (STNil, _) -> kzero +simplifyOneHotTerm (OneHotTerm t SAPHere _ e) kzero ktriv k = case (t, e) of + (SMTNil, _) -> kzero - (STPair{}, ENothing _ _) -> kzero - (STPair{}, EJust _ (EPair _ e1 EZero{})) -> - simplifyOneHotTerm (OneHotTerm t (SAPFst SAPHere) idx e1) kzero ktriv k - (STPair{}, EJust _ (EPair _ EZero{} e2)) -> - simplifyOneHotTerm (OneHotTerm t (SAPSnd SAPHere) idx e2) kzero ktriv k + (SMTPair{}, EPair _ e1 (EZero _ _ ezi)) -> + simplifyOneHotTerm (OneHotTerm t (SAPFst SAPHere) (EPair ext (ENil ext) ezi) e1) kzero ktriv k + (SMTPair{}, EPair _ (EZero _ _ ezi) e2) -> + simplifyOneHotTerm (OneHotTerm t (SAPSnd SAPHere) (EPair ext ezi (ENil ext)) e2) kzero ktriv k - (STEither{}, ENothing _ _) -> kzero - (STEither{}, EJust _ (EInl _ _ e1)) -> - simplifyOneHotTerm (OneHotTerm t (SAPLeft SAPHere) idx e1) kzero ktriv k - (STEither{}, EJust _ (EInr _ _ e2)) -> - simplifyOneHotTerm (OneHotTerm t (SAPRight SAPHere) idx e2) kzero ktriv k + (SMTLEither{}, ELNil _ _ _) -> kzero + (SMTLEither{}, ELInl _ _ e1) -> + simplifyOneHotTerm (OneHotTerm t (SAPLeft SAPHere) (ENil ext) e1) kzero ktriv k + (SMTLEither{}, ELInr _ _ e2) -> + simplifyOneHotTerm (OneHotTerm t (SAPRight SAPHere) (ENil ext) e2) kzero ktriv k - (STMaybe{}, ENothing _ _) -> kzero - (STMaybe{}, EJust _ e1) -> - simplifyOneHotTerm (OneHotTerm t (SAPJust SAPHere) idx e1) kzero ktriv k + (SMTMaybe{}, ENothing _ _) -> kzero + (SMTMaybe{}, EJust _ e1) -> + simplifyOneHotTerm (OneHotTerm t (SAPJust SAPHere) (ENil ext) e1) kzero ktriv k - (STArr{}, ENothing _ _) -> kzero - - (STScal STI32, _) -> kzero - (STScal STI64, _) -> kzero - (STScal STF32, EConst _ _ 0.0) -> kzero - (STScal STF64, EConst _ _ 0.0) -> kzero - (STScal STBool, _) -> kzero + (SMTScal STI32, _) -> kzero + (SMTScal STI64, _) -> kzero + (SMTScal STF32, EConst _ _ 0.0) -> kzero + (SMTScal STF64, EConst _ _ 0.0) -> kzero _ -> ktriv e simplifyOneHotTerm term _ _ k = k term -concatOneHots :: STy a +concatOneHots :: SMTy a -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of (_, SAPHere) -> k prj2 idx2 - (STPair a _, SAPFst prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12 - (STPair _ b, SAPSnd prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12 + (SMTPair a _, SAPFst prj1') -> + concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ))) + (SMTPair _ b, SAPSnd prj1') -> + concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - (STEither a _, SAPLeft prj1') -> + (SMTLEither a _, SAPLeft prj1') -> concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (STEither _ b, SAPRight prj1') -> + (SMTLEither _ b, SAPRight prj1') -> concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - (STMaybe a, SAPJust prj1') -> + (SMTMaybe a, SAPJust prj1') -> concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - (STArr n a, SAPArrIdx prj1' _) -> + (SMTArr _ a, SAPArrIdx prj1') -> concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) + k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) + +zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) +zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) + where + -- invariant: AcIdx expression is duplicable + go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) + go t SAPHere _ e = makeZeroInfo t e + go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) + go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) + go SMTLEither{} _ _ _ = ENil ext + go SMTMaybe{} _ _ _ = ENil ext + go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) + +makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) +makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) + where + -- invariant: expression argument is duplicable + go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) + go SMTNil _ = ENil ext + go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) + go SMTLEither{} _ = ENil ext + go SMTMaybe{} _ = ENil ext + go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e + go SMTScal{} _ = ENil ext |