diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 | 
| commit | b1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch) | |
| tree | a40c16fd082bbe4183e7b4194b8cea1408cec379 /src/AST | |
| parent | c750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff) | |
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of
products and arrays. The idea is to make separate monoid types for a
"product cotangent" and an "array cotangent" that can be lowered to
either a sparse monoid or a non-sparse monoid. Downsides of this
approach: lots of API duplication.
Diffstat (limited to 'src/AST')
| -rw-r--r-- | src/AST/Accum.hs | 90 | ||||
| -rw-r--r-- | src/AST/Count.hs | 6 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 76 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 26 | ||||
| -rw-r--r-- | src/AST/Types.hs | 51 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 145 | 
6 files changed, 280 insertions, 114 deletions
| diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 67c5de7..e84034b 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -8,6 +8,7 @@  module AST.Accum where  import AST.Types +import CHAD.Types  import Data @@ -26,35 +27,90 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where    SAPHere :: SAcPrj APHere a a    SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b    SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b -  SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b -  SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b +  SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b +  SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b    SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b -  -- TODO: This SNat is rather useless, you always have an STy around too -  SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b +  SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b    -- TODO:    -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)  deriving instance Show (SAcPrj p a b)  type family AcIdx p t where    AcIdx APHere t = TNil -  AcIdx (APFst p) (TPair a b) = AcIdx p a -  AcIdx (APSnd p) (TPair a b) = AcIdx p b -  AcIdx (APLeft p) (TEither a b) = AcIdx p a -  AcIdx (APRight p) (TEither a b) = AcIdx p b +  AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b) +  AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b) +  AcIdx (APLeft p) (TLEither a b) = AcIdx p a +  AcIdx (APRight p) (TLEither a b) = AcIdx p b    AcIdx (APJust p) (TMaybe a) = AcIdx p a    AcIdx (APArrIdx p) (TArr n a) = -    -- ((index, array shape), recursive info) -    TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx))) +    -- ((index, shapes info), recursive info) +    TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))            (AcIdx p a)    -- AcIdx (APArrSlice m) (TArr n a) =    --   -- (index, array shape)    --   TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) -acPrjTy :: SAcPrj p a b -> STy a -> STy b +acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b  acPrjTy SAPHere t = t -acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t -acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t -acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t -acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t -acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t -acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t +acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t +acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t +acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t +acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t +acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t +acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t + +type family ZeroInfo t where +  ZeroInfo TNil = TNil +  ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) +  ZeroInfo (TLEither a b) = TNil +  ZeroInfo (TMaybe a) = TNil +  ZeroInfo (TArr n t) = TArr n (ZeroInfo t) +  ZeroInfo (TScal t) = TNil + +tZeroInfo :: SMTy t -> STy (ZeroInfo t) +tZeroInfo SMTNil = STNil +tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) +tZeroInfo (SMTLEither _ _) = STNil +tZeroInfo (SMTMaybe _) = STNil +tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) +tZeroInfo (SMTScal _) = STNil + +lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil +lemZeroInfoD2 STNil = Refl +lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl +lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl +lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl +lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl +lemZeroInfoD2 (STScal STI32) = Refl +lemZeroInfoD2 (STScal STI64) = Refl +lemZeroInfoD2 (STScal STF32) = Refl +lemZeroInfoD2 (STScal STF64) = Refl +lemZeroInfoD2 (STScal STBool) = Refl +lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" +lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl + +-- -- | Additional info needed for accumulation. This is empty unless there is +-- -- sparsity in the monoid. +-- type family AccumInfo t where +--   AccumInfo TNil = TNil +--   AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b) +--   AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +--   AccumInfo (TMaybe a) = TMaybe (AccumInfo a) +--   AccumInfo (TArr n t) = TArr n (AccumInfo t) +--   AccumInfo (TScal t) = TNil + +-- type family PrimalInfo t where +--   PrimalInfo TNil = TNil +--   PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b) +--   PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +--   PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a) +--   PrimalInfo (TArr n t) = TArr n (PrimalInfo t) +--   PrimalInfo (TScal t) = TNil + +-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t) +-- tPrimalInfo SMTNil = STNil +-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a) +-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t) +-- tPrimalInfo (SMTScal _) = STNil diff --git a/src/AST/Count.hs b/src/AST/Count.hs index dc8ec72..feaaa1e 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -113,6 +113,10 @@ occCountGeneral onehot unpush alter many = go WId        ENothing _ _ -> mempty        EJust _ e -> re e        EMaybe _ a b e -> re a <> re1 b <> re e +      ELNil _ _ _ -> mempty +      ELInl _ _ e -> re e +      ELInr _ _ e -> re e +      ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c)        EConstArr{} -> mempty        EBuild _ _ a b -> re a <> many (re1 b)        EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c @@ -130,7 +134,7 @@ occCountGeneral onehot unpush alter many = go WId        ECustom _ _ _ _ _ _ _ a b -> re a <> re b        EWith _ _ a b -> re a <> re1 b        EAccum _ _ _ a b e -> re a <> re b <> re e -      EZero _ _ -> mempty +      EZero _ _ e -> re e        EPlus _ _ a b -> re a <> re b        EOneHot _ _ _ a b -> re a <> re b        EError{} -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index fb5e138..b6ad7d2 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -7,7 +7,7 @@  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE TupleSections #-}  {-# LANGUAGE TypeOperators #-} -module AST.Pretty (pprintExpr, ppExpr, ppSTy, PrettyX(..)) where +module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where  import Control.Monad (ap)  import Data.List (intersperse, intercalate) @@ -152,6 +152,31 @@ ppExpr' d val expr = case expr of      return $ ppParen (d > 10) $        ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e'] +  ELNil _ _ _ -> return (ppString "LNil") + +  ELInl _ _ e -> do +    e' <- ppExpr' 11 val e +    return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e' + +  ELInr _ _ e -> do +    e' <- ppExpr' 11 val e +    return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e' + +  ELCase _ e a b c -> do +    e' <- ppExpr' 0 val e +    let STLEither t1 t2 = typeOf e +    a' <- ppExpr' 11 val a +    name1 <- genNameIfUsedIn t1 IZ b +    b' <- ppExpr' 0 (Const name1 `SCons` val) b +    name2 <- genNameIfUsedIn t2 IZ c +    c' <- ppExpr' 0 (Const name2 `SCons` val) c +    return $ ppParen (d > 0) $ +      hang 2 $ +        annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of") +        <> hardline <> ppString "LNil" <+> ppString "->" <+> a' +        <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b' +        <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c' +    EConstArr _ _ ty v      | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -267,15 +292,17 @@ ppExpr' d val expr = case expr of             <> hardline <> e2')          (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) -  EAccum _ _ prj e1 e2 e3 -> do +  EAccum _ t prj e1 e2 e3 -> do      e1' <- ppExpr' 11 val e1      e2' <- ppExpr' 11 val e2      e3' <- ppExpr' 11 val e3      return $ ppParen (d > 10) $ -      ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3'] +      ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj t prj), e1', e2', e3'] -  EZero _ t -> return $ ppParen (d > 0) $ -    annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t +  EZero _ t e1 -> do +    e1' <- ppExpr' 11 val e1 +    return $ ppParen (d > 0) $ +      annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'    EPlus _ _ a b -> do      a' <- ppExpr' 11 val a @@ -283,11 +310,11 @@ ppExpr' d val expr = case expr of      return $ ppParen (d > 10) $        ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b'] -  EOneHot _ _ prj a b -> do +  EOneHot _ t prj a b -> do      a' <- ppExpr' 11 val a      b' <- ppExpr' 11 val b      return $ ppParen (d > 10) $ -      ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b'] +      ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj t prj), a', b']    EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) @@ -320,14 +347,14 @@ ppLam :: [ADoc] -> ADoc -> ADoc  ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])                                            <> softline <> body <> ppString ")") -ppAcPrj :: SAcPrj p a b -> String -ppAcPrj SAPHere = "@" -ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)" -ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")" -ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)" -ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")" -ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj -ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n) +ppAcPrj :: SMTy a -> SAcPrj p a b -> String +ppAcPrj _ SAPHere = "@" +ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)" +ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)" +ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj +ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n)  ppX :: PrettyX x => Expr x env t -> ADoc  ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) @@ -370,7 +397,24 @@ ppSTy' _ (STScal sty) = ppString $ case sty of    STF32 -> "f32"    STF64 -> "f64"    STBool -> "bool" -ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSTy' 11 t +ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t +ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b + +ppSMTy :: Int -> SMTy t -> String +ppSMTy d ty = render $ ppSMTy' d ty + +ppSMTy' :: Int -> SMTy t -> Doc q +ppSMTy' _ SMTNil = ppString "1" +ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b +ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b +ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t +ppSMTy' d (SMTArr n t) = ppParen (d > 10) $ +  ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t +ppSMTy' _ (SMTScal sty) = ppString $ case sty of +  STI32 -> "i32" +  STI64 -> "i64" +  STF32 -> "f32" +  STF64 -> "f64"  ppString :: String -> Doc x  ppString = fromString diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index dcba1ad..159934d 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -29,6 +29,9 @@ splitLets' = \sub -> \case    EMaybe x a b e ->      let STMaybe t1 = typeOf e      in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e) +  ELCase x e a b c -> +    let STLEither t1 t2 = typeOf e +    in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c)    EFold1Inner x cm a b c ->      let STArr _ t1 = typeOf c      in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) @@ -41,6 +44,9 @@ splitLets' = \sub -> \case    EInr x t e -> EInr x t (splitLets' sub e)    ENothing x t -> ENothing x t    EJust x e -> EJust x (splitLets' sub e) +  ELNil x t1 t2 -> ELNil x t1 t2 +  ELInl x t e -> ELInl x t (splitLets' sub e) +  ELInr x t e -> ELInr x t (splitLets' sub e)    EConstArr x n t a -> EConstArr x n t a    EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b)    ESum1Inner x e -> ESum1Inner x (splitLets' sub e) @@ -57,7 +63,7 @@ splitLets' = \sub -> \case    ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)    EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)    EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) -  EZero x t -> EZero x t +  EZero x t ezi -> EZero x t (splitLets' sub ezi)    EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)    EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)    EError x t s -> EError x t s @@ -121,24 +127,26 @@ split typ = case typ of    STArr{} -> other    STScal{} -> other    STAccum{} -> other +  STLEither{} -> other    where      other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])      other = (Point typ IZ, BTop)  splitRec :: forall env t. Ex env t -> STy t           -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t)) -splitRec rhs = \case +splitRec rhs typ = case typ of    STNil -> (PNil, BTop)    STPair (a :: STy a) (b :: STy b)      | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env ->          let (p1, bs1) = splitRec (EFst ext rhs) a              (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b          in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) -  t@STEither{} -> other t -  t@STMaybe{} -> other t -  t@STArr{} -> other t -  t@STScal{} -> other t -  t@STAccum{} -> other t +  STEither{} -> other +  STMaybe{} -> other +  STArr{} -> other +  STScal{} -> other +  STAccum{} -> other +  STLEither{} -> other    where -    other :: STy t -> (Pointers (t : env) t, Bindings Ex env '[t]) -    other t = (Point t IZ, BPush BTop (t, rhs)) +    other :: (Pointers (t : env) t, Bindings Ex env '[t]) +    other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/AST/Types.hs b/src/AST/Types.hs index b20fc2d..c8515fc 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -27,6 +27,8 @@ type data Ty    | TArr Nat Ty  -- ^ rank, element type    | TScal ScalTy    | TAccum Ty  -- ^ contained type must be a monoid type +  -- sparse monoid types +  | TLEither Ty Ty  type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool @@ -38,7 +40,9 @@ data STy t where    STMaybe :: STy a -> STy (TMaybe a)    STArr :: SNat n -> STy t -> STy (TArr n t)    STScal :: SScalTy t -> STy (TScal t) -  STAccum :: STy t -> STy (TAccum t) +  STAccum :: SMTy t -> STy (TAccum t) +  -- sparse monoid types +  STLEither :: STy a -> STy b -> STy (TLEither a b)  deriving instance Show (STy t)  instance GCompare STy where @@ -56,12 +60,54 @@ instance GCompare STy where      (STScal t) (STScal t') -> gorderingLift1 (gcompare t t')      STScal{} _ -> GLT ; _ STScal{} -> GGT      (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') -    -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT +    STAccum{} _ -> GLT ; _ STAccum{} -> GGT +    (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') +    -- STLEither{} _ -> GLT ; _ STLEither{} -> GGT  instance TestEquality STy where testEquality = geq  instance GEq STy where geq = defaultGeq  instance GShow STy where gshowsPrec = defaultGshowsPrec +-- | Monoid types +type SMTy :: Ty -> Type +data SMTy t where +  SMTNil :: SMTy TNil +  SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b) +  -- TODO: call this SMTLEither +  SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b) +  SMTMaybe :: SMTy a -> SMTy (TMaybe a) +  SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) +  SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) +deriving instance Show (SMTy t) + +instance GCompare SMTy where +  gcompare = \cases +    SMTNil SMTNil -> GEQ +    SMTNil _ -> GLT ; _ SMTNil -> GGT +    (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') +    SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT +    (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') +    SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT +    (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a') +    SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT +    (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') +    SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT +    (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t') +    -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + +instance TestEquality SMTy where testEquality = geq +instance GEq SMTy where geq = defaultGeq +instance GShow SMTy where gshowsPrec = defaultGshowsPrec + +fromSMTy :: SMTy t -> STy t +fromSMTy = \case +  SMTNil -> STNil +  SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2) +  SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2) +  SMTMaybe t -> STMaybe (fromSMTy t) +  SMTArr n t -> STArr n (fromSMTy t) +  SMTScal sty -> STScal sty +  data SScalTy t where    STI32 :: SScalTy TI32    STI64 :: SScalTy TI64 @@ -136,6 +182,7 @@ hasArrays (STMaybe t) = hasArrays t  hasArrays STArr{} = True  hasArrays STScal{} = False  hasArrays STAccum{} = True +hasArrays (STLEither a b) = hasArrays a || hasArrays b  type family Tup env where    Tup '[] = TNil diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 0da1afc..3d5f544 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -5,13 +5,14 @@  module AST.UnMonoid (unMonoid, zero, plus) where  import AST -import CHAD.Types  import Data +-- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them +-- into their concrete implementations.  unMonoid :: Ex env t -> Ex env t  unMonoid = \case -  EZero _ t -> zero t +  EZero _ t e -> zero t e    EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)    EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) @@ -27,6 +28,10 @@ unMonoid = \case    ENothing _ t -> ENothing ext t    EJust _ e -> EJust ext (unMonoid e)    EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) +  ELNil _ t1 t2 -> ELNil ext t1 t2 +  ELInl _ t e -> ELInl ext t (unMonoid e) +  ELInr _ t e -> ELInr ext t (unMonoid e) +  ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)    EConstArr _ n t x -> EConstArr ext n t x    EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)    EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) @@ -46,92 +51,94 @@ unMonoid = \case    EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)    EError _ t s -> EError ext t s -zero :: STy t -> Ex env (D2 t) -zero STNil = ENil ext -zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2)) -zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2)) -zero (STMaybe t) = ENothing ext (d2 t) -zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t)) -zero (STArr n t) = ENothing ext (STArr n (d2 t)) -zero (STScal t) = case t of -  STI32 -> ENil ext -  STI64 -> ENil ext +zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t +zero SMTNil _ = ENil ext +zero (SMTPair t1 t2) e = +  ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) +                         (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) +zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) +zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e +zero (SMTScal t) _ = case t of +  STI32 -> EConst ext STI32 0 +  STI64 -> EConst ext STI64 0    STF32 -> EConst ext STF32 0.0    STF64 -> EConst ext STF64 0.0 -  STBool -> ENil ext -zero STAccum{} = error "Accumulators not allowed in input program" -plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) -plus STNil _ _ = ENil ext -plus (STPair t1 t2) a b = -  let t = STPair (d2 t1) (d2 t2) -  in plusSparse t a b $ +plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t +plus SMTNil _ _ = ENil ext +plus (SMTPair t1 t2) a b = +  let t = STPair (fromSMTy t1) (fromSMTy t2) +  in ELet ext a $ +     ELet ext (weakenExpr WSink b) $         EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))                            (EFst ext (EVar ext t IZ)))                   (plus t2 (ESnd ext (EVar ext t (IS IZ)))                            (ESnd ext (EVar ext t IZ))) -plus (STEither t1 t2) a b = -  let t = STEither (d2 t1) (d2 t2) -  in plusSparse t a b $ -       ECase ext (EVar ext t (IS IZ)) -         (ECase ext (EVar ext t (IS IZ)) -           (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) +plus (SMTLEither t1 t2) a b = +  let t = STLEither (fromSMTy t1) (fromSMTy t2) +  in ELet ext a $ +     ELet ext (weakenExpr WSink b) $ +       ELCase ext (EVar ext t (IS IZ)) +         (EVar ext t IZ) +         (ELCase ext (EVar ext t (IS IZ)) +           (EVar ext t (IS (IS IZ))) +           (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ)))             (EError ext t "plus l+r")) -         (ECase ext (EVar ext t (IS IZ)) +         (ELCase ext (EVar ext t (IS IZ)) +           (EVar ext t (IS (IS IZ)))             (EError ext t "plus r+l") -           (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) -plus (STMaybe t) a b = -  plusSparse (d2 t) a b $ -    plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ) -plus (STArr n t) a b = -  plusSparse (STArr n (d2 t)) a b $ -    eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ)))) -        (EVar ext (STArr n (d2 t)) IZ) -        (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ))) -             (EVar ext (STArr n (d2 t)) (IS IZ)) -             (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)) -               (EVar ext (STArr n (d2 t)) (IS IZ)) -               (EVar ext (STArr n (d2 t)) IZ))) -plus (STScal t) a b = case t of -  STI32 -> ENil ext -  STI64 -> ENil ext -  STF32 -> EOp ext (OAdd STF32) (EPair ext a b) -  STF64 -> EOp ext (OAdd STF64) (EPair ext a b) -  STBool -> ENil ext -plus STAccum{} _ _ = error "Accumulators not allowed in input program" - -plusSparse :: STy a -           -> Ex env (TMaybe a) -> Ex env (TMaybe a) -           -> Ex (a : a : env) a -           -> Ex env (TMaybe a) -plusSparse t a b adder = +           (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ)))) +plus (SMTMaybe t) a b =    ELet ext b $      EMaybe ext -      (EVar ext (STMaybe t) IZ) +      (EVar ext (STMaybe (fromSMTy t)) IZ)        (EJust ext          (EMaybe ext -          (EVar ext t IZ) -          (weakenExpr (WCopy (WCopy WSink)) adder) -          (EVar ext (STMaybe t) (IS IZ)))) +          (EVar ext (fromSMTy t) IZ) +          (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) +          (EVar ext (STMaybe (fromSMTy t)) (IS IZ))))        (weakenExpr WSink a) +plus (SMTArr _ t) a b = +  ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) +           a b +plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) -onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) +onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t  onehot typ topprj idx arg = case (typ, topprj) of -  (_, SAPHere) -> arg +  (_, SAPHere) -> +    ELet ext arg $ +      EVar ext (fromSMTy typ) IZ -  (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2)) -  (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg)) +  (SMTPair t1 t2, SAPFst prj) -> +    ELet ext idx $ +      let tidx = typeOf idx in +      ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ +        let toh = fromSMTy t1 in +        EPair ext (EVar ext toh IZ) +                  (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) +         +  (SMTPair t1 t2, SAPSnd prj) -> +    ELet ext idx $ +      let tidx = typeOf idx in +      ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ +        let toh = fromSMTy t2 in +        EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) +                  (EVar ext toh IZ) -  (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg)) -  (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg)) +  (SMTLEither t1 t2, SAPLeft prj) -> +    ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) +  (SMTLEither t1 t2, SAPRight prj) -> +    ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) -  (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg) +  (SMTMaybe t1, SAPJust prj) -> +    EJust ext (onehot t1 prj idx arg) -  (STArr n t1, SAPArrIdx prj _) -> +  (SMTArr n t1, SAPArrIdx prj) ->      let tidx = tTup (sreplicate n tIx)      in ELet ext idx $ -         EJust ext $ -           EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ -             eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) -               (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) -               (zero t1) +         EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ +           eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) +             (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) +             (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ +                zero t1 (EVar ext (tZeroInfo t1) IZ)) | 
