summaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
commitb1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch)
treea40c16fd082bbe4183e7b4194b8cea1408cec379 /src/AST
parentc750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff)
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of products and arrays. The idea is to make separate monoid types for a "product cotangent" and an "array cotangent" that can be lowered to either a sparse monoid or a non-sparse monoid. Downsides of this approach: lots of API duplication.
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Accum.hs90
-rw-r--r--src/AST/Count.hs6
-rw-r--r--src/AST/Pretty.hs76
-rw-r--r--src/AST/SplitLets.hs26
-rw-r--r--src/AST/Types.hs51
-rw-r--r--src/AST/UnMonoid.hs145
6 files changed, 280 insertions, 114 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 67c5de7..e84034b 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -8,6 +8,7 @@
module AST.Accum where
import AST.Types
+import CHAD.Types
import Data
@@ -26,35 +27,90 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
SAPHere :: SAcPrj APHere a a
SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b
SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b
- SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b
- SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b
+ SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b
+ SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b
SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b
- -- TODO: This SNat is rather useless, you always have an STy around too
- SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b
+ SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b
-- TODO:
-- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
type family AcIdx p t where
AcIdx APHere t = TNil
- AcIdx (APFst p) (TPair a b) = AcIdx p a
- AcIdx (APSnd p) (TPair a b) = AcIdx p b
- AcIdx (APLeft p) (TEither a b) = AcIdx p a
- AcIdx (APRight p) (TEither a b) = AcIdx p b
+ AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b)
+ AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b)
+ AcIdx (APLeft p) (TLEither a b) = AcIdx p a
+ AcIdx (APRight p) (TLEither a b) = AcIdx p b
AcIdx (APJust p) (TMaybe a) = AcIdx p a
AcIdx (APArrIdx p) (TArr n a) =
- -- ((index, array shape), recursive info)
- TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx)))
+ -- ((index, shapes info), recursive info)
+ TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
(AcIdx p a)
-- AcIdx (APArrSlice m) (TArr n a) =
-- -- (index, array shape)
-- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
-acPrjTy :: SAcPrj p a b -> STy a -> STy b
+acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy SAPHere t = t
-acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t
-acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t
-acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t
-acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t
-acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t
-acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t
+acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t
+acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t
+acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t
+acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t
+acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t
+acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t
+
+type family ZeroInfo t where
+ ZeroInfo TNil = TNil
+ ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b)
+ ZeroInfo (TLEither a b) = TNil
+ ZeroInfo (TMaybe a) = TNil
+ ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
+ ZeroInfo (TScal t) = TNil
+
+tZeroInfo :: SMTy t -> STy (ZeroInfo t)
+tZeroInfo SMTNil = STNil
+tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b)
+tZeroInfo (SMTLEither _ _) = STNil
+tZeroInfo (SMTMaybe _) = STNil
+tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
+tZeroInfo (SMTScal _) = STNil
+
+lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil
+lemZeroInfoD2 STNil = Refl
+lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
+lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
+lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl
+lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl
+lemZeroInfoD2 (STScal STI32) = Refl
+lemZeroInfoD2 (STScal STI64) = Refl
+lemZeroInfoD2 (STScal STF32) = Refl
+lemZeroInfoD2 (STScal STF64) = Refl
+lemZeroInfoD2 (STScal STBool) = Refl
+lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program"
+lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
+
+-- -- | Additional info needed for accumulation. This is empty unless there is
+-- -- sparsity in the monoid.
+-- type family AccumInfo t where
+-- AccumInfo TNil = TNil
+-- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b)
+-- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
+-- AccumInfo (TMaybe a) = TMaybe (AccumInfo a)
+-- AccumInfo (TArr n t) = TArr n (AccumInfo t)
+-- AccumInfo (TScal t) = TNil
+
+-- type family PrimalInfo t where
+-- PrimalInfo TNil = TNil
+-- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b)
+-- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
+-- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a)
+-- PrimalInfo (TArr n t) = TArr n (PrimalInfo t)
+-- PrimalInfo (TScal t) = TNil
+
+-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t)
+-- tPrimalInfo SMTNil = STNil
+-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b)
+-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b)
+-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a)
+-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t)
+-- tPrimalInfo (SMTScal _) = STNil
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index dc8ec72..feaaa1e 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -113,6 +113,10 @@ occCountGeneral onehot unpush alter many = go WId
ENothing _ _ -> mempty
EJust _ e -> re e
EMaybe _ a b e -> re a <> re1 b <> re e
+ ELNil _ _ _ -> mempty
+ ELInl _ _ e -> re e
+ ELInr _ _ e -> re e
+ ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c)
EConstArr{} -> mempty
EBuild _ _ a b -> re a <> many (re1 b)
EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
@@ -130,7 +134,7 @@ occCountGeneral onehot unpush alter many = go WId
ECustom _ _ _ _ _ _ _ a b -> re a <> re b
EWith _ _ a b -> re a <> re1 b
EAccum _ _ _ a b e -> re a <> re b <> re e
- EZero _ _ -> mempty
+ EZero _ _ e -> re e
EPlus _ _ a b -> re a <> re b
EOneHot _ _ _ a b -> re a <> re b
EError{} -> mempty
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index fb5e138..b6ad7d2 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -7,7 +7,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
-module AST.Pretty (pprintExpr, ppExpr, ppSTy, PrettyX(..)) where
+module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where
import Control.Monad (ap)
import Data.List (intersperse, intercalate)
@@ -152,6 +152,31 @@ ppExpr' d val expr = case expr of
return $ ppParen (d > 10) $
ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e']
+ ELNil _ _ _ -> return (ppString "LNil")
+
+ ELInl _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e'
+
+ ELInr _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e'
+
+ ELCase _ e a b c -> do
+ e' <- ppExpr' 0 val e
+ let STLEither t1 t2 = typeOf e
+ a' <- ppExpr' 11 val a
+ name1 <- genNameIfUsedIn t1 IZ b
+ b' <- ppExpr' 0 (Const name1 `SCons` val) b
+ name2 <- genNameIfUsedIn t2 IZ c
+ c' <- ppExpr' 0 (Const name2 `SCons` val) c
+ return $ ppParen (d > 0) $
+ hang 2 $
+ annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of")
+ <> hardline <> ppString "LNil" <+> ppString "->" <+> a'
+ <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b'
+ <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c'
+
EConstArr _ _ ty v
| Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
@@ -267,15 +292,17 @@ ppExpr' d val expr = case expr of
<> hardline <> e2')
(ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
- EAccum _ _ prj e1 e2 e3 -> do
+ EAccum _ t prj e1 e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3']
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj t prj), e1', e2', e3']
- EZero _ t -> return $ ppParen (d > 0) $
- annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t
+ EZero _ t e1 -> do
+ e1' <- ppExpr' 11 val e1
+ return $ ppParen (d > 0) $
+ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
EPlus _ _ a b -> do
a' <- ppExpr' 11 val a
@@ -283,11 +310,11 @@ ppExpr' d val expr = case expr of
return $ ppParen (d > 10) $
ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b']
- EOneHot _ _ prj a b -> do
+ EOneHot _ t prj a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b']
+ ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj t prj), a', b']
EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
@@ -320,14 +347,14 @@ ppLam :: [ADoc] -> ADoc -> ADoc
ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
<> softline <> body <> ppString ")")
-ppAcPrj :: SAcPrj p a b -> String
-ppAcPrj SAPHere = "@"
-ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)"
-ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")"
-ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)"
-ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")"
-ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj
-ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n)
+ppAcPrj :: SMTy a -> SAcPrj p a b -> String
+ppAcPrj _ SAPHere = "@"
+ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)"
+ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")"
+ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)"
+ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")"
+ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj
+ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n)
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
@@ -370,7 +397,24 @@ ppSTy' _ (STScal sty) = ppString $ case sty of
STF32 -> "f32"
STF64 -> "f64"
STBool -> "bool"
-ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSTy' 11 t
+ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t
+ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
+
+ppSMTy :: Int -> SMTy t -> String
+ppSMTy d ty = render $ ppSMTy' d ty
+
+ppSMTy' :: Int -> SMTy t -> Doc q
+ppSMTy' _ SMTNil = ppString "1"
+ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b
+ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b
+ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t
+ppSMTy' d (SMTArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t
+ppSMTy' _ (SMTScal sty) = ppString $ case sty of
+ STI32 -> "i32"
+ STI64 -> "i64"
+ STF32 -> "f32"
+ STF64 -> "f64"
ppString :: String -> Doc x
ppString = fromString
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index dcba1ad..159934d 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -29,6 +29,9 @@ splitLets' = \sub -> \case
EMaybe x a b e ->
let STMaybe t1 = typeOf e
in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e)
+ ELCase x e a b c ->
+ let STLEither t1 t2 = typeOf e
+ in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c)
EFold1Inner x cm a b c ->
let STArr _ t1 = typeOf c
in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
@@ -41,6 +44,9 @@ splitLets' = \sub -> \case
EInr x t e -> EInr x t (splitLets' sub e)
ENothing x t -> ENothing x t
EJust x e -> EJust x (splitLets' sub e)
+ ELNil x t1 t2 -> ELNil x t1 t2
+ ELInl x t e -> ELInl x t (splitLets' sub e)
+ ELInr x t e -> ELInr x t (splitLets' sub e)
EConstArr x n t a -> EConstArr x n t a
EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b)
ESum1Inner x e -> ESum1Inner x (splitLets' sub e)
@@ -57,7 +63,7 @@ splitLets' = \sub -> \case
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3)
- EZero x t -> EZero x t
+ EZero x t ezi -> EZero x t (splitLets' sub ezi)
EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
EError x t s -> EError x t s
@@ -121,24 +127,26 @@ split typ = case typ of
STArr{} -> other
STScal{} -> other
STAccum{} -> other
+ STLEither{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])
other = (Point typ IZ, BTop)
splitRec :: forall env t. Ex env t -> STy t
-> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t))
-splitRec rhs = \case
+splitRec rhs typ = case typ of
STNil -> (PNil, BTop)
STPair (a :: STy a) (b :: STy b)
| Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env ->
let (p1, bs1) = splitRec (EFst ext rhs) a
(p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b
in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2)
- t@STEither{} -> other t
- t@STMaybe{} -> other t
- t@STArr{} -> other t
- t@STScal{} -> other t
- t@STAccum{} -> other t
+ STEither{} -> other
+ STMaybe{} -> other
+ STArr{} -> other
+ STScal{} -> other
+ STAccum{} -> other
+ STLEither{} -> other
where
- other :: STy t -> (Pointers (t : env) t, Bindings Ex env '[t])
- other t = (Point t IZ, BPush BTop (t, rhs))
+ other :: (Pointers (t : env) t, Bindings Ex env '[t])
+ other = (Point typ IZ, BPush BTop (typ, rhs))
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index b20fc2d..c8515fc 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -27,6 +27,8 @@ type data Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
| TAccum Ty -- ^ contained type must be a monoid type
+ -- sparse monoid types
+ | TLEither Ty Ty
type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
@@ -38,7 +40,9 @@ data STy t where
STMaybe :: STy a -> STy (TMaybe a)
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
- STAccum :: STy t -> STy (TAccum t)
+ STAccum :: SMTy t -> STy (TAccum t)
+ -- sparse monoid types
+ STLEither :: STy a -> STy b -> STy (TLEither a b)
deriving instance Show (STy t)
instance GCompare STy where
@@ -56,12 +60,54 @@ instance GCompare STy where
(STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
STScal{} _ -> GLT ; _ STScal{} -> GGT
(STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
- -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+ STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+ (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ -- STLEither{} _ -> GLT ; _ STLEither{} -> GGT
instance TestEquality STy where testEquality = geq
instance GEq STy where geq = defaultGeq
instance GShow STy where gshowsPrec = defaultGshowsPrec
+-- | Monoid types
+type SMTy :: Ty -> Type
+data SMTy t where
+ SMTNil :: SMTy TNil
+ SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b)
+ -- TODO: call this SMTLEither
+ SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b)
+ SMTMaybe :: SMTy a -> SMTy (TMaybe a)
+ SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
+ SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
+deriving instance Show (SMTy t)
+
+instance GCompare SMTy where
+ gcompare = \cases
+ SMTNil SMTNil -> GEQ
+ SMTNil _ -> GLT ; _ SMTNil -> GGT
+ (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT
+ (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT
+ (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a')
+ SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT
+ (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
+ SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT
+ (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t')
+ -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT
+
+instance TestEquality SMTy where testEquality = geq
+instance GEq SMTy where geq = defaultGeq
+instance GShow SMTy where gshowsPrec = defaultGshowsPrec
+
+fromSMTy :: SMTy t -> STy t
+fromSMTy = \case
+ SMTNil -> STNil
+ SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2)
+ SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2)
+ SMTMaybe t -> STMaybe (fromSMTy t)
+ SMTArr n t -> STArr n (fromSMTy t)
+ SMTScal sty -> STScal sty
+
data SScalTy t where
STI32 :: SScalTy TI32
STI64 :: SScalTy TI64
@@ -136,6 +182,7 @@ hasArrays (STMaybe t) = hasArrays t
hasArrays STArr{} = True
hasArrays STScal{} = False
hasArrays STAccum{} = True
+hasArrays (STLEither a b) = hasArrays a || hasArrays b
type family Tup env where
Tup '[] = TNil
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 0da1afc..3d5f544 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -5,13 +5,14 @@
module AST.UnMonoid (unMonoid, zero, plus) where
import AST
-import CHAD.Types
import Data
+-- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them
+-- into their concrete implementations.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
- EZero _ t -> zero t
+ EZero _ t e -> zero t e
EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
@@ -27,6 +28,10 @@ unMonoid = \case
ENothing _ t -> ENothing ext t
EJust _ e -> EJust ext (unMonoid e)
EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
+ ELNil _ t1 t2 -> ELNil ext t1 t2
+ ELInl _ t e -> ELInl ext t (unMonoid e)
+ ELInr _ t e -> ELInr ext t (unMonoid e)
+ ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
EConstArr _ n t x -> EConstArr ext n t x
EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
@@ -46,92 +51,94 @@ unMonoid = \case
EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
EError _ t s -> EError ext t s
-zero :: STy t -> Ex env (D2 t)
-zero STNil = ENil ext
-zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
-zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
-zero (STMaybe t) = ENothing ext (d2 t)
-zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t))
-zero (STArr n t) = ENothing ext (STArr n (d2 t))
-zero (STScal t) = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
+zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
+zero SMTNil _ = ENil ext
+zero (SMTPair t1 t2) e =
+ ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
+ (zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
+zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
+zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
+zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
+zero (SMTScal t) _ = case t of
+ STI32 -> EConst ext STI32 0
+ STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
- STBool -> ENil ext
-zero STAccum{} = error "Accumulators not allowed in input program"
-plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
-plus STNil _ _ = ENil ext
-plus (STPair t1 t2) a b =
- let t = STPair (d2 t1) (d2 t2)
- in plusSparse t a b $
+plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
+plus SMTNil _ _ = ENil ext
+plus (SMTPair t1 t2) a b =
+ let t = STPair (fromSMTy t1) (fromSMTy t2)
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
(EFst ext (EVar ext t IZ)))
(plus t2 (ESnd ext (EVar ext t (IS IZ)))
(ESnd ext (EVar ext t IZ)))
-plus (STEither t1 t2) a b =
- let t = STEither (d2 t1) (d2 t2)
- in plusSparse t a b $
- ECase ext (EVar ext t (IS IZ))
- (ECase ext (EVar ext t (IS IZ))
- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
+plus (SMTLEither t1 t2) a b =
+ let t = STLEither (fromSMTy t1) (fromSMTy t2)
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t IZ)
+ (ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t (IS (IS IZ)))
+ (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ)))
(EError ext t "plus l+r"))
- (ECase ext (EVar ext t (IS IZ))
+ (ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t (IS (IS IZ)))
(EError ext t "plus r+l")
- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
-plus (STMaybe t) a b =
- plusSparse (d2 t) a b $
- plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
-plus (STArr n t) a b =
- plusSparse (STArr n (d2 t)) a b $
- eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))
- (EVar ext (STArr n (d2 t)) IZ)
- (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ)))
- (EVar ext (STArr n (d2 t)) (IS IZ))
- (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ))
- (EVar ext (STArr n (d2 t)) (IS IZ))
- (EVar ext (STArr n (d2 t)) IZ)))
-plus (STScal t) a b = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
- STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
- STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
- STBool -> ENil ext
-plus STAccum{} _ _ = error "Accumulators not allowed in input program"
-
-plusSparse :: STy a
- -> Ex env (TMaybe a) -> Ex env (TMaybe a)
- -> Ex (a : a : env) a
- -> Ex env (TMaybe a)
-plusSparse t a b adder =
+ (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ))))
+plus (SMTMaybe t) a b =
ELet ext b $
EMaybe ext
- (EVar ext (STMaybe t) IZ)
+ (EVar ext (STMaybe (fromSMTy t)) IZ)
(EJust ext
(EMaybe ext
- (EVar ext t IZ)
- (weakenExpr (WCopy (WCopy WSink)) adder)
- (EVar ext (STMaybe t) (IS IZ))))
+ (EVar ext (fromSMTy t) IZ)
+ (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
+ (EVar ext (STMaybe (fromSMTy t)) (IS IZ))))
(weakenExpr WSink a)
+plus (SMTArr _ t) a b =
+ ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
+ a b
+plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
+onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
- (_, SAPHere) -> arg
+ (_, SAPHere) ->
+ ELet ext arg $
+ EVar ext (fromSMTy typ) IZ
- (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2))
- (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg))
+ (SMTPair t1 t2, SAPFst prj) ->
+ ELet ext idx $
+ let tidx = typeOf idx in
+ ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
+ let toh = fromSMTy t1 in
+ EPair ext (EVar ext toh IZ)
+ (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))
+
+ (SMTPair t1 t2, SAPSnd prj) ->
+ ELet ext idx $
+ let tidx = typeOf idx in
+ ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
+ let toh = fromSMTy t2 in
+ EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ))))
+ (EVar ext toh IZ)
- (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg))
- (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg))
+ (SMTLEither t1 t2, SAPLeft prj) ->
+ ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
+ (SMTLEither t1 t2, SAPRight prj) ->
+ ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)
- (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg)
+ (SMTMaybe t1, SAPJust prj) ->
+ EJust ext (onehot t1 prj idx arg)
- (STArr n t1, SAPArrIdx prj _) ->
+ (SMTArr n t1, SAPArrIdx prj) ->
let tidx = tTup (sreplicate n tIx)
in ELet ext idx $
- EJust ext $
- EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $
- eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
- (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
- (zero t1)
+ EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
+ eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
+ (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
+ (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
+ zero t1 (EVar ext (tZeroInfo t1) IZ))