{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
-- | Pretty-printing of expressions ('ppExpr') and types ('ppTy').
module AST.Pretty (ppExpr, ppTy) where

import Control.Monad (ap)
import Data.List (intersperse)
import Data.Functor.Const

import AST
import AST.Count
import CHAD.Types
import Data


-- | Names of the variables in an environment, one per entry, looked up with
-- 'slistIdx'.
type SVal = SList (Const String)

-- | A tiny state monad whose state is the counter used for fresh names.
newtype M a = M { runM :: Int -> (a, Int) }
  deriving (Functor)

instance Applicative M where
  pure x = M (\i -> (x, i))
  (<*>) = ap

instance Monad M where
  M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j)

-- | Return the current counter value and advance the counter by one.
genId :: M Int
genId = M (\i -> (i, i + 1))

-- | A fresh name: the given prefix followed by a fresh number.
genName' :: String -> M String
genName' prefix = (prefix ++) . show <$> genId

-- | A fresh name with the default prefix @x@.
genName :: M String
genName = genName' "x"

-- | Name for the binder of type @ty@ at index @idx@ in the scope of @ex@.
--
-- If the variable does not occur in @ex@, no fresh number is consumed and a
-- placeholder is returned instead: @()@ for a unit-typed ('STNil') binder,
-- @_@ for anything else.
genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn' prefix ty idx ex
  | occCount idx ex == mempty = case ty of
      STNil -> return "()"
      _ -> return "_"
  | otherwise = genName' prefix

-- | 'genNameIfUsedIn'' with the default prefix @x@.
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn = genNameIfUsedIn' "x"

-- | Render an expression to a 'String'.
--
-- Every variable of the free environment gets a fresh name first; fresh
-- numbering starts at 1. Only the shape of @senv@ is used, never its
-- elements.
ppExpr :: SList f env -> Expr x env t -> String
ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
  where
    -- Assign a fresh name to every entry of the environment.
    mkVal :: SList f env -> M (SVal env)
    mkVal SNil = return SNil
    mkVal (SCons _ v) = do
      val <- mkVal v
      name <- genName
      return (Const name `SCons` val)

-- | Render an expression at surrounding precedence @d@, following the
-- 'showsPrec' convention: 0 is the top level and 11 is function-argument
-- position. @val@ holds the names of the variables in scope.
ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS
ppExpr' d val = \case
  EVar _ _ i -> return $ showString $ getConst $ slistIdx val i

  e@ELet{} -> ppExprLet d val e

  EPair _ a b -> do
    a' <- ppExpr' 0 val a
    b' <- ppExpr' 0 val b
    return $ showString "(" . a' . showString ", " . b' . showString ")"

  EFst _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "fst " . e'

  ESnd _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "snd " . e'

  ENil _ -> return $ showString "()"

  EInl _ _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "Inl " . e'

  EInr _ _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "Inr " . e'

  ECase _ e a b -> do
    e' <- ppExpr' 0 val e
    let STEither t1 t2 = typeOf e
    name1 <- genNameIfUsedIn t1 IZ a
    a' <- ppExpr' 0 (Const name1 `SCons` val) a
    name2 <- genNameIfUsedIn t2 IZ b
    b' <- ppExpr' 0 (Const name2 `SCons` val) b
    return $ showParen (d > 0) $
      showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
        . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }"

  ENothing _ _ -> return $ showString "nothing"

  EJust _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "Just " . e'

  EMaybe _ a b e -> do
    let STMaybe t = typeOf e
    a' <- ppExpr' 11 val a
    name <- genNameIfUsedIn t IZ b
    -- The just-branch binds a variable, so render it as a lambda; otherwise
    -- the bound name would appear in the output without ever being introduced.
    b' <- ppExpr' 0 (Const name `SCons` val) b
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $
      showString "maybe " . a' . showString (" (\\" ++ name ++ " -> ") . b'
        . showString ") " . e'

  EConstArr _ _ ty v
    | Dict <- scalRepIsShow ty -> return $ showsPrec d v

  EBuild _ n a b -> do
    a' <- ppExpr' 11 val a
    name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
    e' <- ppExpr' 0 (Const name `SCons` val) b
    return $ showParen (d > 10) $
      showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e'
        . showString ")"

  EFold1Inner _ a b c -> do
    name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
    name2 <- genNameIfUsedIn (typeOf a) IZ a
    a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
    b' <- ppExpr' 11 val b
    c' <- ppExpr' 11 val c
    return $ showParen (d > 10) $
      showString ("fold1i (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a'
        . showString ") " . b' . showString " " . c'

  ESum1Inner _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "sum1i " . e'

  EUnit _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "unit " . e'

  EReplicate1Inner _ a b -> do
    a' <- ppExpr' 11 val a
    b' <- ppExpr' 11 val b
    return $ showParen (d > 10) $ showString "replicate1i " . a' . showString " " . b'

  EMaximum1Inner _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "maximum1i " . e'

  EMinimum1Inner _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "minimum1i " . e'

  EConst _ ty v
    | Dict <- scalRepIsShow ty -> return $ showsPrec d v

  EIdx0 _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "idx0 " . e'

  EIdx1 _ a b -> do
    a' <- ppExpr' 9 val a
    b' <- ppExpr' 9 val b
    return $ showParen (d > 8) $ a' . showString " .! " . b'

  EIdx _ a b -> do
    a' <- ppExpr' 9 val a
    b' <- ppExpr' 10 val b
    return $ showParen (d > 8) $ a' . showString " ! " . b'

  EShape _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "shape " . e'

  -- An infix operator applied to a literal pair is printed infix; any other
  -- operator application falls through to the prefix form below.
  EOp _ op (EPair _ a b)
    | (Infix, ops) <- operator op -> do
        a' <- ppExpr' 9 val a
        b' <- ppExpr' 9 val b
        return $ showParen (d > 8) $ a' . showString (" " ++ ops ++ " ") . b'

  EOp _ op e -> do
    e' <- ppExpr' 11 val e
    let ops = case operator op of
          (Infix, s) -> "(" ++ s ++ ")"
          (Prefix, s) -> s
    return $ showParen (d > 10) $ showString (ops ++ " ") . e'

  ECustom _ t1 t2 t3 a b c e1 e2 -> do
    en1 <- genNameIfUsedIn t1 (IS IZ) a
    en2 <- genNameIfUsedIn t2 IZ a
    pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b
    pn2 <- genNameIfUsedIn (d1 t2) IZ b
    dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c
    dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c
    -- Each body is printed under its OWN binder names (a: en*, b: pn*,
    -- c: dn*). The bodies already sit inside the lambda's parentheses, so
    -- they are rendered at precedence 0.
    a' <- ppExpr' 0 (Const en2 `SCons` Const en1 `SCons` SNil) a
    b' <- ppExpr' 0 (Const pn2 `SCons` Const pn1 `SCons` SNil) b
    c' <- ppExpr' 0 (Const dn2 `SCons` Const dn1 `SCons` SNil) c
    e1' <- ppExpr' 11 val e1
    e2' <- ppExpr' 11 val e2
    return $ showParen (d > 10) $
      showString "custom "
        . showString ("(\\" ++ en1 ++ " " ++ en2 ++ " -> ") . a' . showString ") "
        . showString ("(\\" ++ pn1 ++ " " ++ pn2 ++ " -> ") . b' . showString ") "
        . showString ("(\\" ++ dn1 ++ " " ++ dn2 ++ " -> ") . c' . showString ") "
        . e1' . showString " " . e2'

  EWith _ e1 e2 -> do
    e1' <- ppExpr' 11 val e1
    name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
    e2' <- ppExpr' 0 (Const name `SCons` val) e2
    return $ showParen (d > 10) $
      showString "with " . e1' . showString (" (\\" ++ name ++ " -> ") . e2'
        . showString ")"

  EAccum _ i e1 e2 e3 -> do
    e1' <- ppExpr' 11 val e1
    e2' <- ppExpr' 11 val e2
    e3' <- ppExpr' 11 val e3
    return $ showParen (d > 10) $
      showString ("accum " ++ show (fromSNat i) ++ " ")
        . e1' . showString " " . e2' . showString " " . e3'

  EZero _ t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")")

  EPlus _ _ a b -> do
    a' <- ppExpr' 11 val a
    b' <- ppExpr' 11 val b
    return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b'

  EOneHot _ _ i a b -> do
    a' <- ppExpr' 11 val a
    b' <- ppExpr' 11 val b
    return $ showParen (d > 10) $
      showString ("onehot " ++ show (fromSNat i) ++ " ") . a' . showString " " . b'

  EError _ _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)

-- | Render a chain of directly nested 'ELet's as one @let@ block.
--
-- A single binding is printed as @let x = rhs in body@. Several bindings
-- are printed as @let { x = rhs ; y = rhs' } in body@.
ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
ppExprLet d val etop = do
  -- Walk down the ELet spine, naming each binder and rendering its
  -- right-hand side in the scope that precedes it; return the bindings in
  -- order together with the rendered innermost body.
  let collect :: SVal env -> Expr x env t -> M ([(String, Occ, ShowS)], ShowS)
      collect val' (ELet _ rhs body) = do
        let occ = occCount IZ body
        name <- genNameIfUsedIn (typeOf rhs) IZ body
        rhs' <- ppExpr' 0 val' rhs
        (binds, core) <- collect (Const name `SCons` val') body
        return ((name, occ, rhs') : binds, core)
      collect val' e = ([],) <$> ppExpr' 0 val' e

  (binds, core) <- collect val etop

  let (open, close) = case binds of
        [_] -> ("", "")
        _ -> ("{ ", " }")
  return $ showParen (d > 0) $
    showString ("let " ++ open)
      . foldr (.) id
          (intersperse (showString " ; ")
             (map (\(name, _occ, rhs) ->
                     -- _occ is kept for debugging output (see the disabled snippet).
                     showString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") . rhs)
                  binds))
      . showString (close ++ " in ")
      . core

-- | How an operator is written when printed.
data Fixity = Prefix | Infix
  deriving (Show)

-- | Fixity and surface syntax of each primitive operator.
operator :: SOp a t -> (Fixity, String)
operator OAdd{} = (Infix, "+")
operator OMul{} = (Infix, "*")
operator ONeg{} = (Prefix, "negate")
operator OLt{} = (Infix, "<")
operator OLe{} = (Infix, "<=")
operator OEq{} = (Infix, "==")
operator ONot = (Prefix, "not")
operator OAnd = (Infix, "&&")
operator OOr = (Infix, "||")
operator OIf = (Prefix, "ifB")
operator ORound64 = (Prefix, "round")
operator OToFl64 = (Prefix, "toFl64")
operator ORecip{} = (Prefix, "recip")
operator OExp{} = (Prefix, "exp")
operator OLog{} = (Prefix, "log")
operator OIDiv{} = (Infix, "`div`")

-- | Render a type at surrounding precedence @d@.
ppTy :: Int -> STy t -> String
ppTy d ty = ppTys d ty ""

-- | 'ShowS' version of 'ppTy'. Products bind tighter (7) than sums (6);
-- type constructors such as @Maybe@, @Arr@ and @Accum@ are applications
-- (10).
ppTys :: Int -> STy t -> ShowS
ppTys _ STNil = showString "1"
ppTys d (STPair a b) = showParen (d > 7) $ ppTys 8 a . showString " * " . ppTys 8 b
ppTys d (STEither a b) = showParen (d > 6) $ ppTys 7 a . showString " + " . ppTys 7 b
ppTys d (STMaybe t) = showParen (d > 10) $ showString "Maybe " . ppTys 11 t
ppTys d (STArr n t) = showParen (d > 10) $
  showString "Arr " . shows (fromSNat n) . showString " " . ppTys 11 t
ppTys _ (STScal sty) = showString $ case sty of
  STI32 -> "i32"
  STI64 -> "i64"
  STF32 -> "f32"
  STF64 -> "f64"
  STBool -> "bool"
ppTys d (STAccum t) = showParen (d > 10) $ showString "Accum " . ppTys 11 t