summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-01-28 16:58:51 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-01-28 16:58:51 +0100
commit3e04b03acd5e7138e0f6241133585f22ddb73060 (patch)
tree57b60cf7a784e3e1ece6c05afecff52eb4beb6db
parent817cd3c75a2bbbbb355ac33fc7ca3ad8a16bdc92 (diff)
Pretty-printer that supports extension fields
-rw-r--r--chad-fast.cabal5
-rw-r--r--src/AST.hs36
-rw-r--r--src/AST/Pretty.hs225
-rw-r--r--src/Analysis/Identity.hs184
-rw-r--r--src/CHAD/Types.hs1
5 files changed, 308 insertions, 143 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 63479bb..f15bf1d 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -51,9 +51,14 @@ library
containers,
deepseq,
-- template-haskell,
+ prettyprinter,
process,
transformers,
vector,
+
+ ansi-terminal,
+ prettyprinter-ansi-terminal,
+ text,
hs-source-dirs: src
default-language: Haskell2010
ghc-options: -Wall
diff --git a/src/AST.hs b/src/AST.hs
index bcbb19a..99c0681 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -256,6 +256,42 @@ typeOf = \case
EError _ t _ -> t
+extOf :: Expr x env t -> x t
+extOf = \case
+ EVar x _ _ -> x
+ ELet x _ _ -> x
+ EPair x _ _ -> x
+ EFst x _ -> x
+ ESnd x _ -> x
+ ENil x -> x
+ EInl x _ _ -> x
+ EInr x _ _ -> x
+ ECase x _ _ _ -> x
+ ENothing x _ -> x
+ EJust x _ -> x
+ EMaybe x _ _ _ -> x
+ EConstArr x _ _ _ -> x
+ EBuild x _ _ _ -> x
+ EFold1Inner x _ _ _ -> x
+ ESum1Inner x _ -> x
+ EUnit x _ -> x
+ EReplicate1Inner x _ _ -> x
+ EMaximum1Inner x _ -> x
+ EMinimum1Inner x _ -> x
+ EConst x _ _ -> x
+ EIdx0 x _ -> x
+ EIdx1 x _ _ -> x
+ EIdx x _ _ -> x
+ EShape x _ -> x
+ EOp x _ _ -> x
+ ECustom x _ _ _ _ _ _ _ _ -> x
+ EWith x _ _ -> x
+ EAccum x _ _ _ _ -> x
+ EZero x _ -> x
+ EPlus x _ _ _ -> x
+ EOneHot x _ _ _ _ -> x
+ EError x _ _ -> x
+
-- unSNat :: SNat n -> Nat
-- unSNat SZ = Z
-- unSNat (SS n) = S (unSNat n)
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 24bacdb..4190f32 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -1,16 +1,26 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
-module AST.Pretty (ppExpr, ppTy) where
+module AST.Pretty (ppExpr, ppTy, PrettyX(..)) where
import Control.Monad (ap)
import Data.List (intersperse)
import Data.Functor.Const
+import Data.String (fromString)
+import Prettyprinter
+import Prettyprinter.Render.String
+
+import qualified Data.Text.Lazy as TL
+import qualified Prettyprinter.Render.Terminal as PT
+import System.Console.ANSI (hSupportsANSI)
+import System.IO (stdout)
+import System.IO.Unsafe (unsafePerformIO)
import AST
import AST.Count
@@ -18,6 +28,17 @@ import CHAD.Types
import Data
+class PrettyX x where
+ prettyX :: x t -> String
+
+ prettyXsuffix :: x t -> String
+ prettyXsuffix x = "<" ++ prettyX x ++ ">"
+
+instance PrettyX (Const ()) where
+ prettyX _ = ""
+ prettyXsuffix _ = ""
+
+
type SVal = SList (Const String)
newtype M a = M { runM :: Int -> (a, Int) }
@@ -43,8 +64,8 @@ genNameIfUsedIn' prefix ty idx ex
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn = genNameIfUsedIn' "x"
-ppExpr :: SList f env -> Expr x env t -> String
-ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
+ppExpr :: PrettyX x => SList f env -> Expr x env t -> String
+ppExpr senv e = render $ fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1)
where
mkVal :: SList f env -> M (SVal env)
mkVal SNil = return SNil
@@ -53,34 +74,35 @@ ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
name <- genName
return (Const name `SCons` val)
-ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS
-ppExpr' d val = \case
- EVar _ _ i -> return $ showString $ getConst $ slistIdx val i
+ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
+ppExpr' d val expr = case expr of
+ EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr
e@ELet{} -> ppExprLet d val e
EPair _ a b -> do
a' <- ppExpr' 0 val a
b' <- ppExpr' 0 val b
- return $ showString "(" . a' . showString ", " . b' . showString ")"
+ return $ group $ flatAlt (align $ ppString "(" <> a' <> hardline <> ppString "," <> b' <> ppString ")" <> ppX expr)
+ (ppString "(" <> a' <> ppString "," <+> b' <> ppString ")" <> ppX expr)
EFst _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "fst " . e'
+ return $ ppParen (d > 10) $ ppString "fst" <> ppX expr <+> e'
ESnd _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "snd " . e'
+ return $ ppParen (d > 10) $ ppString "snd" <> ppX expr <+> e'
- ENil _ -> return $ showString "()"
+ ENil _ -> return $ ppString "()"
EInl _ _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "Inl " . e'
+ return $ ppParen (d > 10) $ ppString "Inl" <> ppX expr <+> e'
EInr _ _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "Inr " . e'
+ return $ ppParen (d > 10) $ ppString "Inr" <> ppX expr <+> e'
ECase _ e a b -> do
e' <- ppExpr' 0 val e
@@ -89,15 +111,17 @@ ppExpr' d val = \case
a' <- ppExpr' 0 (Const name1 `SCons` val) a
name2 <- genNameIfUsedIn t2 IZ b
b' <- ppExpr' 0 (Const name2 `SCons` val) b
- return $ showParen (d > 0) $
- showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
- . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }"
+ return $ ppParen (d > 0) $
+ hang 2 $
+ annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of")
+ <> hardline <> ppString "Inl" <+> ppString name1 <+> ppString "->" <+> a'
+ <> hardline <> ppString "Inr" <+> ppString name2 <+> ppString "->" <+> b'
- ENothing _ _ -> return $ showString "nothing"
+ ENothing _ _ -> return $ ppString "Nothing"
EJust _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "Just " . e'
+ return $ ppParen (d > 10) $ ppString "Just" <> ppX expr <+> e'
EMaybe _ a b e -> do
let STMaybe t = typeOf e
@@ -105,18 +129,23 @@ ppExpr' d val = \case
name <- genNameIfUsedIn t IZ b
b' <- ppExpr' 11 (Const name `SCons` val) b
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $
- showString "maybe " . a' . showString " " . b' . showString " " . e'
+ return $ ppParen (d > 10) $
+ ppApp (ppString "maybe" <> ppX expr) [a', b', e']
EConstArr _ _ ty v
- | Dict <- scalRepIsShow ty -> return $ showsPrec d v
+ | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
EBuild _ n a b -> do
a' <- ppExpr' 11 val a
name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
e' <- ppExpr' 0 (Const name `SCons` val) b
- return $ showParen (d > 10) $
- showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")"
+ return $ group $ flatAlt
+ (ppParen (d > 0) $
+ hang 2 $
+ annotate AHighlight (ppString "build") <> ppX expr <+> a'
+ <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
+ <> hardline <> e')
+ (ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e'])
EFold1Inner _ a b c -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
@@ -124,65 +153,64 @@ ppExpr' d val = \case
a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
- return $ showParen (d > 10) $
- showString ("fold1i (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a'
- . showString ") " . b' . showString " " . c'
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString "fold1i") <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
ESum1Inner _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "sum1i " . e'
+ return $ ppParen (d > 10) $ ppString "sum1i" <> ppX expr <+> e'
EUnit _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "unit " . e'
+ return $ ppParen (d > 10) $ ppString "unit" <> ppX expr <+> e'
EReplicate1Inner _ a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
- return $ showParen (d > 10) $ showString "replicate1i " . a' . showString " " . b'
+ return $ ppParen (d > 10) $ ppApp (ppString "replicate1i" <> ppX expr) [a', b']
EMaximum1Inner _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "maximum1i " . e'
+ return $ ppParen (d > 10) $ ppString "maximum1i" <> ppX expr <+> e'
EMinimum1Inner _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "minimum1i " . e'
+ return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e'
EConst _ ty v
- | Dict <- scalRepIsShow ty -> return $ showsPrec d v
+ | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
EIdx0 _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "idx0 " . e'
+ return $ ppParen (d > 10) $ ppString "idx0" <> ppX expr <+> e'
EIdx1 _ a b -> do
a' <- ppExpr' 9 val a
b' <- ppExpr' 9 val b
- return $ showParen (d > 8) $ a' . showString " .! " . b'
+ return $ ppParen (d > 8) $ a' <+> ppString ".!" <> ppX expr <+> b'
EIdx _ a b -> do
a' <- ppExpr' 9 val a
b' <- ppExpr' 10 val b
- return $ showParen (d > 8) $
- a' . showString " ! " . b'
+ return $ ppParen (d > 8) $
+ a' <+> ppString "!" <+> b'
EShape _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "shape " . e'
+ return $ ppParen (d > 10) $ ppString "shape" <> ppX expr <+> e'
EOp _ op (EPair _ a b)
| (Infix, ops) <- operator op -> do
a' <- ppExpr' 9 val a
b' <- ppExpr' 9 val b
- return $ showParen (d > 8) $ a' . showString (" " ++ ops ++ " ") . b'
+ return $ ppParen (d > 8) $ a' <+> ppString ops <> ppX expr <+> b'
EOp _ op e -> do
e' <- ppExpr' 11 val e
let ops = case operator op of
(Infix, s) -> "(" ++ s ++ ")"
(Prefix, s) -> s
- return $ showParen (d > 10) $ showString (ops ++ " ") . e'
+ return $ ppParen (d > 10) $ ppString ops <> ppX expr <+> e'
ECustom _ t1 t2 t3 a b c e1 e2 -> do
en1 <- genNameIfUsedIn t1 (IS IZ) a
@@ -196,46 +224,53 @@ ppExpr' d val = \case
c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
- return $ showParen (d > 10) $ showString "custom "
- . showString ("(\\" ++ en1 ++ " " ++ en2 ++ " -> ") . a' . showString ") "
- . showString ("(\\" ++ pn1 ++ " " ++ pn2 ++ " -> ") . b' . showString ") "
- . showString ("(\\" ++ dn1 ++ " " ++ dn2 ++ " -> ") . c' . showString ") "
- . e1' . showString " "
- . e2'
+ return $ ppParen (d > 10) $
+ ppApp (ppString "custom" <> ppX expr)
+ [ppLam [ppString en1, ppString en2] a'
+ ,ppLam [ppString pn1, ppString pn2] b'
+ ,ppLam [ppString dn1, ppString dn2] c'
+ ,e1'
+ ,e2']
EWith _ e1 e2 -> do
e1' <- ppExpr' 11 val e1
name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
e2' <- ppExpr' 0 (Const name `SCons` val) e2
- return $ showParen (d > 10) $
- showString "with " . e1' . showString (" (\\" ++ name ++ " -> ")
- . e2' . showString ")"
+ return $ group $ flatAlt
+ (ppParen (d > 0) $
+ hang 2 $
+ annotate AWith (ppString "with") <> ppX expr <+> e1'
+ <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
+ <> hardline <> e2')
+ (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
EAccum _ i e1 e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
- return $ showParen (d > 10) $
- showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
+ return $ ppParen (d > 10) $
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (show (fromSNat i)), e1', e2', e3']
- EZero _ t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")")
+ EZero _ t -> return $ parens $
+ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "::" <+> ppTy' 0 t <> ppString ")"
EPlus _ _ a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
- return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b'
+ return $ ppParen (d > 10) $
+ ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b']
EOneHot _ _ i a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
- return $ showParen (d > 10) $
- showString ("onehot " ++ show (fromSNat i) ++ " ") . a' . showString " " . b'
+ return $ ppParen (d > 10) $
+ ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (show (fromSNat i)), a', b']
- EError _ _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
+ EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
-ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
+ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
ppExprLet d val etop = do
- let collect :: SVal env -> Expr x env t -> M ([(String, Occ, ShowS)], ShowS)
+ let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc)
collect val' (ELet _ rhs body) = do
let occ = occCount IZ body
name <- genNameIfUsedIn (typeOf rhs) IZ body
@@ -246,18 +281,24 @@ ppExprLet d val etop = do
(binds, core) <- collect val etop
- let (open, close) = case binds of
- [_] -> ("", "")
- _ -> ("{ ", " }")
- return $ showParen (d > 0) $
- showString ("let " ++ open)
- . foldr (.) id
- (intersperse (showString " ; ")
- (map (\(name, _occ, rhs) ->
- showString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") . rhs)
- binds))
- . showString (close ++ " in ")
- . core
+ return $ ppParen (d > 0) $
+ align $
+ annotate AKey (ppString "let")
+ <+> align (mconcat $ intersperse hardline $
+ map (\(name, _occ, rhs) ->
+ ppString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") <> rhs)
+ binds)
+ <> hardline <> annotate AKey (ppString "in") <+> core
+
+ppApp :: ADoc -> [ADoc] -> ADoc
+ppApp fun args = group $ fun <+> align (sep args)
+
+ppLam :: [ADoc] -> ADoc -> ADoc
+ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
+ <> softline <> body <> ppString ")")
+
+ppX :: PrettyX x => Expr x env t -> ADoc
+ppX expr = ppString $ prettyXsuffix (extOf expr)
data Fixity = Prefix | Infix
deriving (Show)
@@ -281,19 +322,45 @@ operator OLog{} = (Prefix, "log")
operator OIDiv{} = (Infix, "`div`")
ppTy :: Int -> STy t -> String
-ppTy d ty = ppTys d ty ""
-
-ppTys :: Int -> STy t -> ShowS
-ppTys _ STNil = showString "1"
-ppTys d (STPair a b) = showParen (d > 7) $ ppTys 8 a . showString " * " . ppTys 8 b
-ppTys d (STEither a b) = showParen (d > 6) $ ppTys 7 a . showString " + " . ppTys 7 b
-ppTys d (STMaybe t) = showParen (d > 10) $ showString "Maybe " . ppTys 11 t
-ppTys d (STArr n t) = showParen (d > 10) $
- showString "Arr " . shows (fromSNat n) . showString " " . ppTys 11 t
-ppTys _ (STScal sty) = showString $ case sty of
+ppTy d ty = render $ ppTy' d ty
+
+ppTy' :: Int -> STy t -> Doc q
+ppTy' _ STNil = ppString "1"
+ppTy' d (STPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b
+ppTy' d (STEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b
+ppTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t
+ppTy' d (STArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppTy' 11 t
+ppTy' _ (STScal sty) = ppString $ case sty of
STI32 -> "i32"
STI64 -> "i64"
STF32 -> "f32"
STF64 -> "f64"
STBool -> "bool"
-ppTys d (STAccum t) = showParen (d > 10) $ showString "Accum " . ppTys 11 t
+ppTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t
+
+ppString :: String -> Doc x
+ppString = fromString
+
+ppParen :: Bool -> Doc x -> Doc x
+ppParen True = parens
+ppParen False = id
+
+data Annot = AKey | AWith | AHighlight | AMonoid
+ deriving (Show)
+
+annotToANSI :: Annot -> PT.AnsiStyle
+annotToANSI AKey = PT.bold
+annotToANSI AWith = PT.color PT.Red <> PT.underlined
+annotToANSI AHighlight = PT.color PT.Blue
+annotToANSI AMonoid = PT.color PT.Green
+
+type ADoc = Doc Annot
+
+render :: Doc Annot -> String
+render =
+ (if stdoutTTY then TL.unpack . PT.renderLazy . reAnnotateS annotToANSI
+ else renderString)
+ . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 }
+ where
+ stdoutTTY = unsafePerformIO $ hSupportsANSI stdout
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 9087143..285cfb8 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -1,11 +1,14 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
module Analysis.Identity (
identityAnalysis,
) where
import AST
+import AST.Pretty (PrettyX(..))
+import CHAD.Types (d1, d2)
import Data
import Util.IdGen
@@ -19,13 +22,14 @@ data ValId t where
VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative
VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case
VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)
+ VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value
VIArr :: Int -> ValId (TArr n t)
VIScal :: Int -> ValId (TScal t)
VIAccum :: Int -> ValId (TAccum t)
-- | We don't know what this consists of, but it's a value, and let's just
-- give it an ID nevertheless.
- VIThing :: Int -> ValId t
+ VIThing :: STy t -> Int -> ValId t
instance Eq (ValId t) where
VINil == VINil = True
@@ -38,33 +42,40 @@ instance Eq (ValId t) where
VIEither'{} == _ = False
VIMaybe a == VIMaybe a' = a == a'
VIMaybe{} == _ = False
+ VIMaybe' a == VIMaybe' a' = a == a'
+ VIMaybe'{} == _ = False
VIArr i == VIArr i' = i == i'
VIArr{} == _ = False
VIScal i == VIScal i' = i == i'
VIScal{} == _ = False
VIAccum i == VIAccum i' = i == i'
VIAccum{} == _ = False
- VIThing i == VIThing i' = i == i'
+ VIThing _ i == VIThing _ i' = i == i'
VIThing{} == _ = False
+instance PrettyX ValId where
+ prettyX = \case
+ VINil -> ""
+ VIPair a b -> "(" ++ prettyX a ++ "," ++ prettyX b ++ ")"
+ VIEither (Left a) -> "(L" ++ prettyX a ++ ")"
+ VIEither (Right a) -> "(R" ++ prettyX a ++ ")"
+ VIEither' a b -> "(" ++ prettyX a ++ "|" ++ prettyX b ++ ")"
+ VIMaybe Nothing -> "N"
+ VIMaybe (Just a) -> 'J' : prettyX a
+ VIMaybe' a -> 'M' : prettyX a
+ VIArr i -> 'A' : show i
+ VIScal i -> show i
+ VIAccum i -> 'C' : show i
+ VIThing _ i -> '{' : show i ++ "}"
+
-- | Symbolic partial evaluation.
identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t
identityAnalysis env term = runIdGen 0 $ do
- env' <- slistMapA numberConstant env
+ env' <- slistMapA genIds env
snd <$> idana env' term
- where
- numberConstant :: STy t -> IdGen (ValId t)
- numberConstant = \case
- STNil -> pure VINil
- STPair a b -> VIPair <$> numberConstant a <*> numberConstant b
- STEither a b -> VIEither' <$> numberConstant a <*> numberConstant b
- STMaybe{} -> VIThing <$> genId
- STArr{} -> VIArr <$> genId
- STScal{} -> VIScal <$> genId
- STAccum{} -> VIAccum <$> genId
idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t)
-idana env = \case
+idana env expr = case expr of
EVar _ t i -> do
let v = slistIdx env i
pure (v, EVar v t i)
@@ -82,13 +93,13 @@ idana env = \case
EFst _ e -> do
(v, e') <- idana env e
v' <- case v of VIPair v1 _ -> pure v1
- _ -> VIThing <$> genId
+ _ -> genIds (typeOf expr)
pure (v', EFst v' e')
ESnd _ e -> do
(v, e') <- idana env e
v' <- case v of VIPair _ v2 -> pure v2
- _ -> VIThing <$> genId
+ _ -> genIds (typeOf expr)
pure (v', ESnd v' e')
ENil _ -> pure (VINil, ENil VINil)
@@ -104,33 +115,31 @@ idana env = \case
pure (v, EInr v t1 e2')
ECase _ e1 e2 e3 -> do
+ let STEither t1 t2 = typeOf e1
(v1, e1') <- idana env e1
case v1 of
VIEither (Left v1') -> do
(v2, e2') <- idana (v1' `SCons` env) e2
- scrap <- VIThing <$> genId
+ scrap <- genIds t2
(_, e3') <- idana (scrap `SCons` env) e3
pure (v2, ECase v2 e1' e2' e3')
VIEither (Right v1') -> do
- scrap <- VIThing <$> genId
+ scrap <- genIds t1
(_, e2') <- idana (scrap `SCons` env) e2
(v3, e3') <- idana (v1' `SCons` env) e3
pure (v3, ECase v3 e1' e2' e3')
VIEither' v1'l v1'r -> do
(_, e2') <- idana (v1'l `SCons` env) e2
(_, e3') <- idana (v1'r `SCons` env) e3
- res <- genId
- pure (VIThing res, ECase (VIThing res) e1' e2' e3')
- VIThing _ -> do
- x2 <- genId
- x3 <- genId
- (v2, e2') <- idana (VIThing x2 `SCons` env) e2
- (v3, e3') <- idana (VIThing x3 `SCons` env) e3
- if v2 == v3
- then pure (v2, ECase v2 e1' e2' e3')
- else do
- res <- genId
- pure (VIThing res, ECase (VIThing res) e1' e2' e3')
+ res <- genIds (typeOf expr)
+ pure (res, ECase res e1' e2' e3')
+ VIThing _ _ -> do
+ x2 <- genIds t1
+ x3 <- genIds t2
+ (v2, e2') <- idana (x2 `SCons` env) e2
+ (v3, e3') <- idana (x3 `SCons` env) e3
+ res <- unify v2 v3
+ pure (res, ECase res e1' e2' e3')
ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t)
@@ -140,26 +149,29 @@ idana env = \case
pure (v, EJust v e1')
EMaybe _ e1 e2 e3 -> do
+ let STMaybe t1 = typeOf e3
(v3, e3') <- idana env e3
case v3 of
VIMaybe Nothing -> do
(v1, e1') <- idana env e1
- scrap <- VIThing <$> genId
+ scrap <- genIds t1
(_, e2') <- idana (scrap `SCons` env) e2
pure (v1, EMaybe v1 e1' e2' e3')
VIMaybe (Just v3j) -> do
(v2, e2') <- idana (v3j `SCons` env) e2
(_, e1') <- idana env e1
pure (v2, EMaybe v2 e1' e2' e3')
- VIThing _ -> do
+ VIMaybe' v3' -> do
+ (v2, e2') <- idana (v3' `SCons` env) e2
(v1, e1') <- idana env e1
- scrap <- VIThing <$> genId
+ res <- unify v1 v2
+ pure (res, EMaybe res e1' e2' e3')
+ VIThing _ _ -> do
+ (v1, e1') <- idana env e1
+ scrap <- genIds t1
(v2, e2') <- idana (scrap `SCons` env) e2
- if v1 == v2
- then pure (v2, EMaybe v2 e1' e2' e3')
- else do
- res <- genId
- pure (VIThing res, EMaybe (VIThing res) e1' e2' e3')
+ res <- unify v1 v2
+ pure (res, EMaybe res e1' e2' e3')
EConstArr _ dim t arr -> do
x1 <- VIArr <$> genId
@@ -167,15 +179,16 @@ idana env = \case
EBuild _ dim e1 e2 -> do
(_, e1') <- idana env e1
- scrap <- VIThing <$> genId
- (_, e2') <- idana (scrap `SCons` env) e2
+ x1 <- genIds (tTup (sreplicate dim tIx))
+ (_, e2') <- idana (x1 `SCons` env) e2
res <- VIArr <$> genId
pure (res, EBuild res dim e1' e2')
EFold1Inner _ e1 e2 e3 -> do
- scrap1 <- VIThing <$> genId
- scrap2 <- VIThing <$> genId
- (_, e1') <- idana (scrap1 `SCons` scrap2 `SCons` env) e1
+ let t1 = typeOf e1
+ x1 <- genIds t1
+ x2 <- genIds t1
+ (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1
(_, e2') <- idana env e2
(_, e3') <- idana env e3
res <- VIArr <$> genId
@@ -213,51 +226,53 @@ idana env = \case
EIdx0 _ e1 -> do
(_, e1') <- idana env e1
- res <- VIThing <$> genId
+ res <- genIds (typeOf expr)
pure (res, EIdx0 res e1')
EIdx1 _ e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
- res <- VIThing <$> genId
+ res <- genIds (typeOf expr)
pure (res, EIdx1 res e1' e2')
EIdx _ e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
- res <- VIThing <$> genId
+ res <- genIds (typeOf expr)
pure (res, EIdx res e1' e2')
EShape _ e1 -> do
(_, e1') <- idana env e1
- res <- VIThing <$> genId
+ res <- genIds (typeOf expr)
pure (res, EShape res e1')
- EOp _ op e1 -> do
+ EOp _ (op :: SOp a t) e1 -> do
(_, e1') <- idana env e1
- res <- VIThing <$> genId
+ res <- genIds (typeOf expr)
pure (res, EOp res op e1')
ECustom _ t1 t2 t3 e1 e2 e3 e4 e5 -> do
- x1 <- VIThing <$> genId
- x2 <- VIThing <$> genId
+ let t4 = typeOf e1
+ x1 <- genIds t2
+ x2 <- genIds t1
(_, e1') <- idana (x1 `SCons` x2 `SCons` SNil) e1
- x3 <- VIThing <$> genId
- x4 <- VIThing <$> genId
+ x3 <- genIds (d1 t2)
+ x4 <- genIds (d1 t1)
(_, e2') <- idana (x3 `SCons` x4 `SCons` SNil) e2
- x5 <- VIThing <$> genId
- x6 <- VIThing <$> genId
+ x5 <- genIds (d2 t4)
+ x6 <- genIds t3
(_, e3') <- idana (x5 `SCons` x6 `SCons` SNil) e3
(_, e4') <- idana env e4
(_, e5') <- idana env e5
- res <- VIThing <$> genId
+ res <- genIds t4
pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5')
EWith _ e1 e2 -> do
+ let t1 = typeOf e1
(_, e1') <- idana env e1
x1 <- VIAccum <$> genId
(v2, e2') <- idana (x1 `SCons` env) e2
- x2 <- VIThing <$> genId
+ x2 <- genIds t1
let res = VIPair v2 x2
pure (res, EWith res e1' e2')
@@ -265,25 +280,66 @@ idana env = \case
(_, e1') <- idana env e1
(_, e2') <- idana env e2
(_, e3') <- idana env e3
- res <- VIThing <$> genId
- pure (res, EAccum res i e1' e2' e3')
+ pure (VINil, EAccum VINil i e1' e2' e3')
EZero _ t -> do
- res <- VIThing <$> genId
+ res <- genIds (d2 t)
pure (res, EZero res t)
EPlus _ t e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
- res <- VIThing <$> genId
+ res <- genIds (d2 t)
pure (res, EPlus res t e1' e2')
EOneHot _ t i e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
- res <- VIThing <$> genId
+ res <- genIds (d2 t)
pure (res, EOneHot res t i e1' e2')
EError _ t s -> do
- res <- VIThing <$> genId
+ res <- genIds t
pure (res, EError res t s)
+
+-- | This value might be either of the two arguments; we don't know which.
+unify :: ValId t -> ValId t -> IdGen (ValId t)
+unify VINil VINil = pure VINil
+unify (VIPair a b) (VIPair c d) = VIPair <$> unify a c <*> unify b d
+unify (VIEither (Left a)) (VIEither (Left b)) = VIEither . Left <$> unify a b
+unify (VIEither (Right a)) (VIEither (Right b)) = VIEither . Right <$> unify a b
+unify (VIEither (Left a)) (VIEither (Right b)) = pure $ VIEither' a b
+unify (VIEither (Right a)) (VIEither (Left b)) = pure $ VIEither' b a
+unify (VIEither (Left a)) (VIEither' b c) = VIEither' <$> unify a b <*> pure c
+unify (VIEither (Right a)) (VIEither' b c) = VIEither' <$> pure b <*> unify a c
+unify (VIEither' a b) (VIEither (Left c)) = VIEither' <$> unify a c <*> pure b
+unify (VIEither' a b) (VIEither (Right c)) = VIEither' <$> pure a <*> unify b c
+unify (VIEither' a b) (VIEither' c d) = VIEither' <$> unify a c <*> unify b d
+unify (VIMaybe Nothing) (VIMaybe Nothing) = pure $ VIMaybe Nothing
+unify (VIMaybe (Just a)) (VIMaybe (Just b)) = VIMaybe . Just <$> unify a b
+unify (VIMaybe Nothing) (VIMaybe (Just a)) = pure $ VIMaybe' a
+unify (VIMaybe (Just a)) (VIMaybe Nothing) = pure $ VIMaybe' a
+unify (VIMaybe Nothing) (VIMaybe' a) = pure $ VIMaybe' a
+unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b
+unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a
+unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b
+unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b
+unify (VIArr i) (VIArr j) | i == j = pure $ VIArr i
+ | otherwise = VIArr <$> genId
+unify (VIScal i) (VIScal j) | i == j = pure $ VIScal i
+ | otherwise = VIScal <$> genId
+unify (VIAccum i) (VIAccum j) | i == j = pure $ VIAccum i
+ | otherwise = VIAccum <$> genId
+unify (VIThing t i) (VIThing _ j) | i == j = pure $ VIThing t i
+ | otherwise = genIds t
+unify (VIThing t _) _ = genIds t
+unify _ (VIThing t _) = genIds t
+
+genIds :: STy t -> IdGen (ValId t)
+genIds STNil = pure VINil
+genIds (STPair a b) = VIPair <$> genIds a <*> genIds b
+genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b
+genIds (STMaybe t) = VIMaybe' <$> genIds t
+genIds STArr{} = VIArr <$> genId
+genIds STScal{} = VIScal <$> genId
+genIds STAccum{} = VIAccum <$> genId
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 6662cbf..fd1b6b1 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -86,6 +86,7 @@ data CHADConfig = CHADConfig
, -- | Introduce top-level arguments containing arrays in accumulator mode.
chcArgArrayAccum :: Bool
}
+ deriving (Show)
defaultConfig :: CHADConfig
defaultConfig = CHADConfig