diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-01-28 16:58:51 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-01-28 16:58:51 +0100 |
commit | 3e04b03acd5e7138e0f6241133585f22ddb73060 (patch) | |
tree | 57b60cf7a784e3e1ece6c05afecff52eb4beb6db | |
parent | 817cd3c75a2bbbbb355ac33fc7ca3ad8a16bdc92 (diff) |
Pretty-printer that supports extension fields
-rw-r--r-- | chad-fast.cabal | 5 | ||||
-rw-r--r-- | src/AST.hs | 36 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 225 | ||||
-rw-r--r-- | src/Analysis/Identity.hs | 184 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 1 |
5 files changed, 308 insertions, 143 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 63479bb..f15bf1d 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -51,9 +51,14 @@ library containers, deepseq, -- template-haskell, + prettyprinter, process, transformers, vector, + + ansi-terminal, + prettyprinter-ansi-terminal, + text, hs-source-dirs: src default-language: Haskell2010 ghc-options: -Wall @@ -256,6 +256,42 @@ typeOf = \case EError _ t _ -> t +extOf :: Expr x env t -> x t +extOf = \case + EVar x _ _ -> x + ELet x _ _ -> x + EPair x _ _ -> x + EFst x _ -> x + ESnd x _ -> x + ENil x -> x + EInl x _ _ -> x + EInr x _ _ -> x + ECase x _ _ _ -> x + ENothing x _ -> x + EJust x _ -> x + EMaybe x _ _ _ -> x + EConstArr x _ _ _ -> x + EBuild x _ _ _ -> x + EFold1Inner x _ _ _ -> x + ESum1Inner x _ -> x + EUnit x _ -> x + EReplicate1Inner x _ _ -> x + EMaximum1Inner x _ -> x + EMinimum1Inner x _ -> x + EConst x _ _ -> x + EIdx0 x _ -> x + EIdx1 x _ _ -> x + EIdx x _ _ -> x + EShape x _ -> x + EOp x _ _ -> x + ECustom x _ _ _ _ _ _ _ _ -> x + EWith x _ _ -> x + EAccum x _ _ _ _ -> x + EZero x _ -> x + EPlus x _ _ _ -> x + EOneHot x _ _ _ _ -> x + EError x _ _ -> x + -- unSNat :: SNat n -> Nat -- unSNat SZ = Z -- unSNat (SS n) = S (unSNat n) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 24bacdb..4190f32 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -1,16 +1,26 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} -module AST.Pretty (ppExpr, ppTy) where +module AST.Pretty (ppExpr, ppTy, PrettyX(..)) where import Control.Monad (ap) import Data.List (intersperse) import Data.Functor.Const +import Data.String (fromString) +import Prettyprinter +import Prettyprinter.Render.String + +import qualified Data.Text.Lazy as TL +import qualified Prettyprinter.Render.Terminal as PT +import System.Console.ANSI (hSupportsANSI) +import System.IO (stdout) +import System.IO.Unsafe (unsafePerformIO) import AST import AST.Count @@ -18,6 +28,17 @@ import CHAD.Types import Data +class PrettyX x where + prettyX :: x t -> String + + prettyXsuffix :: x t -> String + prettyXsuffix x = "<" ++ prettyX x ++ ">" + +instance PrettyX (Const ()) where + prettyX _ = "" + prettyXsuffix _ = "" + + type SVal = SList (Const String) newtype M a = M { runM :: Int -> (a, Int) } @@ -43,8 +64,8 @@ genNameIfUsedIn' prefix ty idx ex genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = genNameIfUsedIn' "x" -ppExpr :: SList f env -> Expr x env t -> String -ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) "" +ppExpr :: PrettyX x => SList f env -> Expr x env t -> String +ppExpr senv e = render $ fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) where mkVal :: SList f env -> M (SVal env) mkVal SNil = return SNil @@ -53,34 +74,35 @@ ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) "" name <- genName return (Const name `SCons` val) -ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS -ppExpr' d val = \case - EVar _ _ i -> return $ showString $ getConst $ slistIdx val i +ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc +ppExpr' d val expr = case expr of + EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr e@ELet{} -> ppExprLet d val e EPair _ a b -> do a' <- ppExpr' 0 val a b' <- ppExpr' 0 val b - return $ showString "(" . a' . showString ", " . b' . showString ")" + return $ group $ flatAlt (align $ ppString "(" <> a' <> hardline <> ppString "," <> b' <> ppString ")" <> ppX expr) + (ppString "(" <> a' <> ppString "," <+> b' <> ppString ")" <> ppX expr) EFst _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "fst " . e' + return $ ppParen (d > 10) $ ppString "fst" <> ppX expr <+> e' ESnd _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "snd " . e' + return $ ppParen (d > 10) $ ppString "snd" <> ppX expr <+> e' - ENil _ -> return $ showString "()" + ENil _ -> return $ ppString "()" EInl _ _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "Inl " . e' + return $ ppParen (d > 10) $ ppString "Inl" <> ppX expr <+> e' EInr _ _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "Inr " . e' + return $ ppParen (d > 10) $ ppString "Inr" <> ppX expr <+> e' ECase _ e a b -> do e' <- ppExpr' 0 val e @@ -89,15 +111,17 @@ ppExpr' d val = \case a' <- ppExpr' 0 (Const name1 `SCons` val) a name2 <- genNameIfUsedIn t2 IZ b b' <- ppExpr' 0 (Const name2 `SCons` val) b - return $ showParen (d > 0) $ - showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a' - . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }" + return $ ppParen (d > 0) $ + hang 2 $ + annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of") + <> hardline <> ppString "Inl" <+> ppString name1 <+> ppString "->" <+> a' + <> hardline <> ppString "Inr" <+> ppString name2 <+> ppString "->" <+> b' - ENothing _ _ -> return $ showString "nothing" + ENothing _ _ -> return $ ppString "Nothing" EJust _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "Just " . e' + return $ ppParen (d > 10) $ ppString "Just" <> ppX expr <+> e' EMaybe _ a b e -> do let STMaybe t = typeOf e @@ -105,18 +129,23 @@ ppExpr' d val = \case name <- genNameIfUsedIn t IZ b b' <- ppExpr' 11 (Const name `SCons` val) b e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ - showString "maybe " . a' . showString " " . b' . showString " " . e' + return $ ppParen (d > 10) $ + ppApp (ppString "maybe" <> ppX expr) [a', b', e'] EConstArr _ _ ty v - | Dict <- scalRepIsShow ty -> return $ showsPrec d v + | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr EBuild _ n a b -> do a' <- ppExpr' 11 val a name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b e' <- ppExpr' 0 (Const name `SCons` val) b - return $ showParen (d > 10) $ - showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")" + return $ group $ flatAlt + (ppParen (d > 0) $ + hang 2 $ + annotate AHighlight (ppString "build") <> ppX expr <+> a' + <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" + <> hardline <> e') + (ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e']) EFold1Inner _ a b c -> do name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a @@ -124,65 +153,64 @@ ppExpr' d val = \case a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c - return $ showParen (d > 10) $ - showString ("fold1i (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a' - . showString ") " . b' . showString " " . c' + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString "fold1i") <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] ESum1Inner _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "sum1i " . e' + return $ ppParen (d > 10) $ ppString "sum1i" <> ppX expr <+> e' EUnit _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "unit " . e' + return $ ppParen (d > 10) $ ppString "unit" <> ppX expr <+> e' EReplicate1Inner _ a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b - return $ showParen (d > 10) $ showString "replicate1i " . a' . showString " " . b' + return $ ppParen (d > 10) $ ppApp (ppString "replicate1i" <> ppX expr) [a', b'] EMaximum1Inner _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "maximum1i " . e' + return $ ppParen (d > 10) $ ppString "maximum1i" <> ppX expr <+> e' EMinimum1Inner _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "minimum1i " . e' + return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' EConst _ ty v - | Dict <- scalRepIsShow ty -> return $ showsPrec d v + | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr EIdx0 _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "idx0 " . e' + return $ ppParen (d > 10) $ ppString "idx0" <> ppX expr <+> e' EIdx1 _ a b -> do a' <- ppExpr' 9 val a b' <- ppExpr' 9 val b - return $ showParen (d > 8) $ a' . showString " .! " . b' + return $ ppParen (d > 8) $ a' <+> ppString ".!" <> ppX expr <+> b' EIdx _ a b -> do a' <- ppExpr' 9 val a b' <- ppExpr' 10 val b - return $ showParen (d > 8) $ - a' . showString " ! " . b' + return $ ppParen (d > 8) $ + a' <+> ppString "!" <+> b' EShape _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "shape " . e' + return $ ppParen (d > 10) $ ppString "shape" <> ppX expr <+> e' EOp _ op (EPair _ a b) | (Infix, ops) <- operator op -> do a' <- ppExpr' 9 val a b' <- ppExpr' 9 val b - return $ showParen (d > 8) $ a' . showString (" " ++ ops ++ " ") . b' + return $ ppParen (d > 8) $ a' <+> ppString ops <> ppX expr <+> b' EOp _ op e -> do e' <- ppExpr' 11 val e let ops = case operator op of (Infix, s) -> "(" ++ s ++ ")" (Prefix, s) -> s - return $ showParen (d > 10) $ showString (ops ++ " ") . e' + return $ ppParen (d > 10) $ ppString ops <> ppX expr <+> e' ECustom _ t1 t2 t3 a b c e1 e2 -> do en1 <- genNameIfUsedIn t1 (IS IZ) a @@ -196,46 +224,53 @@ ppExpr' d val = \case c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 - return $ showParen (d > 10) $ showString "custom " - . showString ("(\\" ++ en1 ++ " " ++ en2 ++ " -> ") . a' . showString ") " - . showString ("(\\" ++ pn1 ++ " " ++ pn2 ++ " -> ") . b' . showString ") " - . showString ("(\\" ++ dn1 ++ " " ++ dn2 ++ " -> ") . c' . showString ") " - . e1' . showString " " - . e2' + return $ ppParen (d > 10) $ + ppApp (ppString "custom" <> ppX expr) + [ppLam [ppString en1, ppString en2] a' + ,ppLam [ppString pn1, ppString pn2] b' + ,ppLam [ppString dn1, ppString dn2] c' + ,e1' + ,e2'] EWith _ e1 e2 -> do e1' <- ppExpr' 11 val e1 name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 e2' <- ppExpr' 0 (Const name `SCons` val) e2 - return $ showParen (d > 10) $ - showString "with " . e1' . showString (" (\\" ++ name ++ " -> ") - . e2' . showString ")" + return $ group $ flatAlt + (ppParen (d > 0) $ + hang 2 $ + annotate AWith (ppString "with") <> ppX expr <+> e1' + <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" + <> hardline <> e2') + (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) EAccum _ i e1 e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 - return $ showParen (d > 10) $ - showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' + return $ ppParen (d > 10) $ + ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (show (fromSNat i)), e1', e2', e3'] - EZero _ t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")") + EZero _ t -> return $ parens $ + annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "::" <+> ppTy' 0 t <> ppString ")" EPlus _ _ a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b - return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b' + return $ ppParen (d > 10) $ + ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b'] EOneHot _ _ i a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b - return $ showParen (d > 10) $ - showString ("onehot " ++ show (fromSNat i) ++ " ") . a' . showString " " . b' + return $ ppParen (d > 10) $ + ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (show (fromSNat i)), a', b'] - EError _ _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) + EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) -ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS +ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc ppExprLet d val etop = do - let collect :: SVal env -> Expr x env t -> M ([(String, Occ, ShowS)], ShowS) + let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc) collect val' (ELet _ rhs body) = do let occ = occCount IZ body name <- genNameIfUsedIn (typeOf rhs) IZ body @@ -246,18 +281,24 @@ ppExprLet d val etop = do (binds, core) <- collect val etop - let (open, close) = case binds of - [_] -> ("", "") - _ -> ("{ ", " }") - return $ showParen (d > 0) $ - showString ("let " ++ open) - . foldr (.) id - (intersperse (showString " ; ") - (map (\(name, _occ, rhs) -> - showString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") . rhs) - binds)) - . showString (close ++ " in ") - . core + return $ ppParen (d > 0) $ + align $ + annotate AKey (ppString "let") + <+> align (mconcat $ intersperse hardline $ + map (\(name, _occ, rhs) -> + ppString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") <> rhs) + binds) + <> hardline <> annotate AKey (ppString "in") <+> core + +ppApp :: ADoc -> [ADoc] -> ADoc +ppApp fun args = group $ fun <+> align (sep args) + +ppLam :: [ADoc] -> ADoc -> ADoc +ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) + <> softline <> body <> ppString ")") + +ppX :: PrettyX x => Expr x env t -> ADoc +ppX expr = ppString $ prettyXsuffix (extOf expr) data Fixity = Prefix | Infix deriving (Show) @@ -281,19 +322,45 @@ operator OLog{} = (Prefix, "log") operator OIDiv{} = (Infix, "`div`") ppTy :: Int -> STy t -> String -ppTy d ty = ppTys d ty "" - -ppTys :: Int -> STy t -> ShowS -ppTys _ STNil = showString "1" -ppTys d (STPair a b) = showParen (d > 7) $ ppTys 8 a . showString " * " . ppTys 8 b -ppTys d (STEither a b) = showParen (d > 6) $ ppTys 7 a . showString " + " . ppTys 7 b -ppTys d (STMaybe t) = showParen (d > 10) $ showString "Maybe " . ppTys 11 t -ppTys d (STArr n t) = showParen (d > 10) $ - showString "Arr " . shows (fromSNat n) . showString " " . ppTys 11 t -ppTys _ (STScal sty) = showString $ case sty of +ppTy d ty = render $ ppTy' d ty + +ppTy' :: Int -> STy t -> Doc q +ppTy' _ STNil = ppString "1" +ppTy' d (STPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b +ppTy' d (STEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b +ppTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t +ppTy' d (STArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppTy' 11 t +ppTy' _ (STScal sty) = ppString $ case sty of STI32 -> "i32" STI64 -> "i64" STF32 -> "f32" STF64 -> "f64" STBool -> "bool" -ppTys d (STAccum t) = showParen (d > 10) $ showString "Accum " . ppTys 11 t +ppTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t + +ppString :: String -> Doc x +ppString = fromString + +ppParen :: Bool -> Doc x -> Doc x +ppParen True = parens +ppParen False = id + +data Annot = AKey | AWith | AHighlight | AMonoid + deriving (Show) + +annotToANSI :: Annot -> PT.AnsiStyle +annotToANSI AKey = PT.bold +annotToANSI AWith = PT.color PT.Red <> PT.underlined +annotToANSI AHighlight = PT.color PT.Blue +annotToANSI AMonoid = PT.color PT.Green + +type ADoc = Doc Annot + +render :: Doc Annot -> String +render = + (if stdoutTTY then TL.unpack . PT.renderLazy . reAnnotateS annotToANSI + else renderString) + . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 } + where + stdoutTTY = unsafePerformIO $ hSupportsANSI stdout diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 9087143..285cfb8 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -1,11 +1,14 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} module Analysis.Identity ( identityAnalysis, ) where import AST +import AST.Pretty (PrettyX(..)) +import CHAD.Types (d1, d2) import Data import Util.IdGen @@ -19,13 +22,14 @@ data ValId t where VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a) + VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value VIArr :: Int -> ValId (TArr n t) VIScal :: Int -> ValId (TScal t) VIAccum :: Int -> ValId (TAccum t) -- | We don't know what this consists of, but it's a value, and let's just -- give it an ID nevertheless. - VIThing :: Int -> ValId t + VIThing :: STy t -> Int -> ValId t instance Eq (ValId t) where VINil == VINil = True @@ -38,33 +42,40 @@ instance Eq (ValId t) where VIEither'{} == _ = False VIMaybe a == VIMaybe a' = a == a' VIMaybe{} == _ = False + VIMaybe' a == VIMaybe' a' = a == a' + VIMaybe'{} == _ = False VIArr i == VIArr i' = i == i' VIArr{} == _ = False VIScal i == VIScal i' = i == i' VIScal{} == _ = False VIAccum i == VIAccum i' = i == i' VIAccum{} == _ = False - VIThing i == VIThing i' = i == i' + VIThing _ i == VIThing _ i' = i == i' VIThing{} == _ = False +instance PrettyX ValId where + prettyX = \case + VINil -> "" + VIPair a b -> "(" ++ prettyX a ++ "," ++ prettyX b ++ ")" + VIEither (Left a) -> "(L" ++ prettyX a ++ ")" + VIEither (Right a) -> "(R" ++ prettyX a ++ ")" + VIEither' a b -> "(" ++ prettyX a ++ "|" ++ prettyX b ++ ")" + VIMaybe Nothing -> "N" + VIMaybe (Just a) -> 'J' : prettyX a + VIMaybe' a -> 'M' : prettyX a + VIArr i -> 'A' : show i + VIScal i -> show i + VIAccum i -> 'C' : show i + VIThing _ i -> '{' : show i ++ "}" + -- | Symbolic partial evaluation. identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t identityAnalysis env term = runIdGen 0 $ do - env' <- slistMapA numberConstant env + env' <- slistMapA genIds env snd <$> idana env' term - where - numberConstant :: STy t -> IdGen (ValId t) - numberConstant = \case - STNil -> pure VINil - STPair a b -> VIPair <$> numberConstant a <*> numberConstant b - STEither a b -> VIEither' <$> numberConstant a <*> numberConstant b - STMaybe{} -> VIThing <$> genId - STArr{} -> VIArr <$> genId - STScal{} -> VIScal <$> genId - STAccum{} -> VIAccum <$> genId idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t) -idana env = \case +idana env expr = case expr of EVar _ t i -> do let v = slistIdx env i pure (v, EVar v t i) @@ -82,13 +93,13 @@ idana env = \case EFst _ e -> do (v, e') <- idana env e v' <- case v of VIPair v1 _ -> pure v1 - _ -> VIThing <$> genId + _ -> genIds (typeOf expr) pure (v', EFst v' e') ESnd _ e -> do (v, e') <- idana env e v' <- case v of VIPair _ v2 -> pure v2 - _ -> VIThing <$> genId + _ -> genIds (typeOf expr) pure (v', ESnd v' e') ENil _ -> pure (VINil, ENil VINil) @@ -104,33 +115,31 @@ idana env = \case pure (v, EInr v t1 e2') ECase _ e1 e2 e3 -> do + let STEither t1 t2 = typeOf e1 (v1, e1') <- idana env e1 case v1 of VIEither (Left v1') -> do (v2, e2') <- idana (v1' `SCons` env) e2 - scrap <- VIThing <$> genId + scrap <- genIds t2 (_, e3') <- idana (scrap `SCons` env) e3 pure (v2, ECase v2 e1' e2' e3') VIEither (Right v1') -> do - scrap <- VIThing <$> genId + scrap <- genIds t1 (_, e2') <- idana (scrap `SCons` env) e2 (v3, e3') <- idana (v1' `SCons` env) e3 pure (v3, ECase v3 e1' e2' e3') VIEither' v1'l v1'r -> do (_, e2') <- idana (v1'l `SCons` env) e2 (_, e3') <- idana (v1'r `SCons` env) e3 - res <- genId - pure (VIThing res, ECase (VIThing res) e1' e2' e3') - VIThing _ -> do - x2 <- genId - x3 <- genId - (v2, e2') <- idana (VIThing x2 `SCons` env) e2 - (v3, e3') <- idana (VIThing x3 `SCons` env) e3 - if v2 == v3 - then pure (v2, ECase v2 e1' e2' e3') - else do - res <- genId - pure (VIThing res, ECase (VIThing res) e1' e2' e3') + res <- genIds (typeOf expr) + pure (res, ECase res e1' e2' e3') + VIThing _ _ -> do + x2 <- genIds t1 + x3 <- genIds t2 + (v2, e2') <- idana (x2 `SCons` env) e2 + (v3, e3') <- idana (x3 `SCons` env) e3 + res <- unify v2 v3 + pure (res, ECase res e1' e2' e3') ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t) @@ -140,26 +149,29 @@ idana env = \case pure (v, EJust v e1') EMaybe _ e1 e2 e3 -> do + let STMaybe t1 = typeOf e3 (v3, e3') <- idana env e3 case v3 of VIMaybe Nothing -> do (v1, e1') <- idana env e1 - scrap <- VIThing <$> genId + scrap <- genIds t1 (_, e2') <- idana (scrap `SCons` env) e2 pure (v1, EMaybe v1 e1' e2' e3') VIMaybe (Just v3j) -> do (v2, e2') <- idana (v3j `SCons` env) e2 (_, e1') <- idana env e1 pure (v2, EMaybe v2 e1' e2' e3') - VIThing _ -> do + VIMaybe' v3' -> do + (v2, e2') <- idana (v3' `SCons` env) e2 (v1, e1') <- idana env e1 - scrap <- VIThing <$> genId + res <- unify v1 v2 + pure (res, EMaybe res e1' e2' e3') + VIThing _ _ -> do + (v1, e1') <- idana env e1 + scrap <- genIds t1 (v2, e2') <- idana (scrap `SCons` env) e2 - if v1 == v2 - then pure (v2, EMaybe v2 e1' e2' e3') - else do - res <- genId - pure (VIThing res, EMaybe (VIThing res) e1' e2' e3') + res <- unify v1 v2 + pure (res, EMaybe res e1' e2' e3') EConstArr _ dim t arr -> do x1 <- VIArr <$> genId @@ -167,15 +179,16 @@ idana env = \case EBuild _ dim e1 e2 -> do (_, e1') <- idana env e1 - scrap <- VIThing <$> genId - (_, e2') <- idana (scrap `SCons` env) e2 + x1 <- genIds (tTup (sreplicate dim tIx)) + (_, e2') <- idana (x1 `SCons` env) e2 res <- VIArr <$> genId pure (res, EBuild res dim e1' e2') EFold1Inner _ e1 e2 e3 -> do - scrap1 <- VIThing <$> genId - scrap2 <- VIThing <$> genId - (_, e1') <- idana (scrap1 `SCons` scrap2 `SCons` env) e1 + let t1 = typeOf e1 + x1 <- genIds t1 + x2 <- genIds t1 + (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 (_, e2') <- idana env e2 (_, e3') <- idana env e3 res <- VIArr <$> genId @@ -213,51 +226,53 @@ idana env = \case EIdx0 _ e1 -> do (_, e1') <- idana env e1 - res <- VIThing <$> genId + res <- genIds (typeOf expr) pure (res, EIdx0 res e1') EIdx1 _ e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- VIThing <$> genId + res <- genIds (typeOf expr) pure (res, EIdx1 res e1' e2') EIdx _ e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- VIThing <$> genId + res <- genIds (typeOf expr) pure (res, EIdx res e1' e2') EShape _ e1 -> do (_, e1') <- idana env e1 - res <- VIThing <$> genId + res <- genIds (typeOf expr) pure (res, EShape res e1') - EOp _ op e1 -> do + EOp _ (op :: SOp a t) e1 -> do (_, e1') <- idana env e1 - res <- VIThing <$> genId + res <- genIds (typeOf expr) pure (res, EOp res op e1') ECustom _ t1 t2 t3 e1 e2 e3 e4 e5 -> do - x1 <- VIThing <$> genId - x2 <- VIThing <$> genId + let t4 = typeOf e1 + x1 <- genIds t2 + x2 <- genIds t1 (_, e1') <- idana (x1 `SCons` x2 `SCons` SNil) e1 - x3 <- VIThing <$> genId - x4 <- VIThing <$> genId + x3 <- genIds (d1 t2) + x4 <- genIds (d1 t1) (_, e2') <- idana (x3 `SCons` x4 `SCons` SNil) e2 - x5 <- VIThing <$> genId - x6 <- VIThing <$> genId + x5 <- genIds (d2 t4) + x6 <- genIds t3 (_, e3') <- idana (x5 `SCons` x6 `SCons` SNil) e3 (_, e4') <- idana env e4 (_, e5') <- idana env e5 - res <- VIThing <$> genId + res <- genIds t4 pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5') EWith _ e1 e2 -> do + let t1 = typeOf e1 (_, e1') <- idana env e1 x1 <- VIAccum <$> genId (v2, e2') <- idana (x1 `SCons` env) e2 - x2 <- VIThing <$> genId + x2 <- genIds t1 let res = VIPair v2 x2 pure (res, EWith res e1' e2') @@ -265,25 +280,66 @@ idana env = \case (_, e1') <- idana env e1 (_, e2') <- idana env e2 (_, e3') <- idana env e3 - res <- VIThing <$> genId - pure (res, EAccum res i e1' e2' e3') + pure (VINil, EAccum VINil i e1' e2' e3') EZero _ t -> do - res <- VIThing <$> genId + res <- genIds (d2 t) pure (res, EZero res t) EPlus _ t e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- VIThing <$> genId + res <- genIds (d2 t) pure (res, EPlus res t e1' e2') EOneHot _ t i e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- VIThing <$> genId + res <- genIds (d2 t) pure (res, EOneHot res t i e1' e2') EError _ t s -> do - res <- VIThing <$> genId + res <- genIds t pure (res, EError res t s) + +-- | This value might be either of the two arguments; we don't know which. +unify :: ValId t -> ValId t -> IdGen (ValId t) +unify VINil VINil = pure VINil +unify (VIPair a b) (VIPair c d) = VIPair <$> unify a c <*> unify b d +unify (VIEither (Left a)) (VIEither (Left b)) = VIEither . Left <$> unify a b +unify (VIEither (Right a)) (VIEither (Right b)) = VIEither . Right <$> unify a b +unify (VIEither (Left a)) (VIEither (Right b)) = pure $ VIEither' a b +unify (VIEither (Right a)) (VIEither (Left b)) = pure $ VIEither' b a +unify (VIEither (Left a)) (VIEither' b c) = VIEither' <$> unify a b <*> pure c +unify (VIEither (Right a)) (VIEither' b c) = VIEither' <$> pure b <*> unify a c +unify (VIEither' a b) (VIEither (Left c)) = VIEither' <$> unify a c <*> pure b +unify (VIEither' a b) (VIEither (Right c)) = VIEither' <$> pure a <*> unify b c +unify (VIEither' a b) (VIEither' c d) = VIEither' <$> unify a c <*> unify b d +unify (VIMaybe Nothing) (VIMaybe Nothing) = pure $ VIMaybe Nothing +unify (VIMaybe (Just a)) (VIMaybe (Just b)) = VIMaybe . Just <$> unify a b +unify (VIMaybe Nothing) (VIMaybe (Just a)) = pure $ VIMaybe' a +unify (VIMaybe (Just a)) (VIMaybe Nothing) = pure $ VIMaybe' a +unify (VIMaybe Nothing) (VIMaybe' a) = pure $ VIMaybe' a +unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a +unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b +unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VIArr i) (VIArr j) | i == j = pure $ VIArr i + | otherwise = VIArr <$> genId +unify (VIScal i) (VIScal j) | i == j = pure $ VIScal i + | otherwise = VIScal <$> genId +unify (VIAccum i) (VIAccum j) | i == j = pure $ VIAccum i + | otherwise = VIAccum <$> genId +unify (VIThing t i) (VIThing _ j) | i == j = pure $ VIThing t i + | otherwise = genIds t +unify (VIThing t _) _ = genIds t +unify _ (VIThing t _) = genIds t + +genIds :: STy t -> IdGen (ValId t) +genIds STNil = pure VINil +genIds (STPair a b) = VIPair <$> genIds a <*> genIds b +genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b +genIds (STMaybe t) = VIMaybe' <$> genIds t +genIds STArr{} = VIArr <$> genId +genIds STScal{} = VIScal <$> genId +genIds STAccum{} = VIAccum <$> genId diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 6662cbf..fd1b6b1 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -86,6 +86,7 @@ data CHADConfig = CHADConfig , -- | Introduce top-level arguments containing arrays in accumulator mode. chcArgArrayAccum :: Bool } + deriving (Show) defaultConfig :: CHADConfig defaultConfig = CHADConfig |