From 9ed1fe1a12831896dc9d010a59eb16d016984a26 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 25 Apr 2025 13:35:35 +0200 Subject: Don't unSTy --- src/AST/Pretty.hs | 36 +++++++---------- src/AST/Types.hs | 116 +++++++++++++++++++----------------------------------- src/Compile.hs | 10 ++--- src/Data.hs | 31 +++++++++++---- 4 files changed, 84 insertions(+), 109 deletions(-) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 4f637f2..01dfcf8 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -7,7 +7,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} -module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppTy, PrettyX(..)) where +module AST.Pretty (pprintExpr, ppExpr, ppSTy, PrettyX(..)) where import Control.Monad (ap) import Data.List (intersperse, intercalate) @@ -354,28 +354,22 @@ operator OIDiv{} = (Infix, "`div`") operator OMod{} = (Infix, "`mod`") ppSTy :: Int -> STy t -> String -ppSTy d ty = ppTy d (unSTy ty) +ppSTy d ty = render $ ppSTy' d ty ppSTy' :: Int -> STy t -> Doc q -ppSTy' d ty = ppTy' d (unSTy ty) - -ppTy :: Int -> Ty -> String -ppTy d ty = render $ ppTy' d ty - -ppTy' :: Int -> Ty -> Doc q -ppTy' _ TNil = ppString "1" -ppTy' d (TPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b -ppTy' d (TEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b -ppTy' d (TMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t -ppTy' d (TArr n t) = ppParen (d > 10) $ - ppString "Arr " <> ppString (show (fromNat n)) <> ppString " " <> ppTy' 11 t -ppTy' _ (TScal sty) = ppString $ case sty of - TI32 -> "i32" - TI64 -> "i64" - TF32 -> "f32" - TF64 -> "f64" - TBool -> "bool" -ppTy' d (TAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t +ppSTy' _ STNil = ppString "1" +ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b +ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b +ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t +ppSTy' d (STArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t +ppSTy' _ (STScal sty) = ppString $ case sty of + STI32 -> "i32" + STI64 -> "i64" + STF32 -> "f32" + STF64 -> "f64" + STBool -> "bool" +ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSTy' 11 t ppString :: String -> Doc x ppString = fromString diff --git a/src/AST/Types.hs b/src/AST/Types.hs index 217b2f5..b20fc2d 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -1,34 +1,34 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeData #-} module AST.Types where import Data.Int (Int32, Int64) +import Data.GADT.Compare import Data.GADT.Show import Data.Kind (Type) -import Data.Some import Data.Type.Equality import Data -data Ty +type data Ty = TNil | TPair Ty Ty | TEither Ty Ty | TMaybe Ty | TArr Nat Ty -- ^ rank, element type | TScal ScalTy - | TAccum Ty -- ^ the accumulator contains D2 of this type - deriving (Show, Eq, Ord) + | TAccum Ty -- ^ contained type must be a monoid type -data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool - deriving (Show, Eq, Ord) +type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool type STy :: Ty -> Type data STy t where @@ -41,22 +41,25 @@ data STy t where STAccum :: STy t -> STy (TAccum t) deriving instance Show (STy t) -instance TestEquality STy where - testEquality STNil STNil = Just Refl - testEquality STNil _ = Nothing - testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality STPair{} _ = Nothing - testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality STEither{} _ = Nothing - testEquality (STMaybe a) (STMaybe a') | Just Refl <- testEquality a a' = Just Refl - testEquality STMaybe{} _ = Nothing - testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality STArr{} _ = Nothing - testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl - testEquality STScal{} _ = Nothing - testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl - testEquality STAccum{} _ = Nothing - +instance GCompare STy where + gcompare = \cases + STNil STNil -> GEQ + STNil _ -> GLT ; _ STNil -> GGT + (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STPair{} _ -> GLT ; _ STPair{} -> GGT + (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STEither{} _ -> GLT ; _ STEither{} -> GGT + (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a') + STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT + (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') + STArr{} _ -> GLT ; _ STArr{} -> GGT + (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') + STScal{} _ -> GLT ; _ STScal{} -> GGT + (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') + -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT + +instance TestEquality STy where testEquality = geq +instance GEq STy where geq = defaultGeq instance GShow STy where gshowsPrec = defaultGshowsPrec data SScalTy t where @@ -67,14 +70,21 @@ data SScalTy t where STBool :: SScalTy TBool deriving instance Show (SScalTy t) -instance TestEquality SScalTy where - testEquality STI32 STI32 = Just Refl - testEquality STI64 STI64 = Just Refl - testEquality STF32 STF32 = Just Refl - testEquality STF64 STF64 = Just Refl - testEquality STBool STBool = Just Refl - testEquality _ _ = Nothing - +instance GCompare SScalTy where + gcompare = \cases + STI32 STI32 -> GEQ + STI32 _ -> GLT ; _ STI32 -> GGT + STI64 STI64 -> GEQ + STI64 _ -> GLT ; _ STI64 -> GGT + STF32 STF32 -> GEQ + STF32 _ -> GLT ; _ STF32 -> GGT + STF64 STF64 -> GEQ + STF64 _ -> GLT ; _ STF64 -> GGT + STBool STBool -> GEQ + -- STBool _ -> GLT ; _ STBool -> GGT + +instance TestEquality SScalTy where testEquality = geq +instance GEq SScalTy where geq = defaultGeq instance GShow SScalTy where gshowsPrec = defaultGshowsPrec scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) @@ -89,50 +99,6 @@ type TIx = TScal TI64 tIx :: STy TIx tIx = STScal STI64 -unSTy :: STy t -> Ty -unSTy = \case - STNil -> TNil - STPair a b -> TPair (unSTy a) (unSTy b) - STEither a b -> TEither (unSTy a) (unSTy b) - STMaybe t -> TMaybe (unSTy t) - STArr n t -> TArr (unSNat n) (unSTy t) - STScal t -> TScal (unSScalTy t) - STAccum t -> TAccum (unSTy t) - -unSEnv :: SList STy env -> [Ty] -unSEnv SNil = [] -unSEnv (SCons t l) = unSTy t : unSEnv l - -unSScalTy :: SScalTy t -> ScalTy -unSScalTy = \case - STI32 -> TI32 - STI64 -> TI64 - STF32 -> TF32 - STF64 -> TF64 - STBool -> TBool - -reSTy :: Ty -> Some STy -reSTy = \case - TNil -> Some STNil - TPair a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STPair a' b' - TEither a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STEither a' b' - TMaybe t | Some t' <- reSTy t -> Some $ STMaybe t' - TArr n t | Some n' <- reSNat n, Some t' <- reSTy t -> Some $ STArr n' t' - TScal t | Some t' <- reSScalTy t -> Some $ STScal t' - TAccum t | Some t' <- reSTy t -> Some $ STAccum t' - -reSEnv :: [Ty] -> Some (SList STy) -reSEnv [] = Some SNil -reSEnv (t : l) | Some t' <- reSTy t, Some env <- reSEnv l = Some (SCons t' env) - -reSScalTy :: ScalTy -> Some SScalTy -reSScalTy = \case - TI32 -> Some STI32 - TI64 -> Some STI64 - TF32 -> Some STF32 - TF64 -> Some STF64 - TBool -> Some STBool - type family ScalRep t where ScalRep TI32 = Int32 ScalRep TI64 = Int64 @@ -161,7 +127,7 @@ type family ScalIsIntegral t where ScalIsIntegral TF64 = False ScalIsIntegral TBool = False --- | Returns true for arrays /and/ accumulators; +-- | Returns true for arrays /and/ accumulators. hasArrays :: STy t' -> Bool hasArrays STNil = False hasArrays (STPair a b) = hasArrays a || hasArrays b diff --git a/src/Compile.hs b/src/Compile.hs index e3eb207..e2d004a 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -282,11 +282,11 @@ genStructs ty = do tell (BList (genStruct name ty)) -genAllStructs :: Foldable t => t Ty -> [StructDecl] -genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\t -> case reSTy t of Some t' -> genStructs t') tys)) mempty +genAllStructs :: Foldable t => t (Some STy) -> [StructDecl] +genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty data CompState = CompState - { csStructs :: Set Ty + { csStructs :: Set (Some STy) , csTopLevelDecls :: Bag String , csStmts :: Bag Stmt , csNextId :: Int } @@ -329,7 +329,7 @@ scope m = do emitStruct :: STy t -> CompM String emitStruct ty = CompM $ do - modify $ \s -> s { csStructs = Set.insert (unSTy ty) (csStructs s) } + modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } return (genStructName ty) emitTLD :: String -> CompM () @@ -348,7 +348,7 @@ compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets) compileToString codeID env expr = let args = nameEnv env (res, s) = runCompM (compile' args expr) - structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env)) + structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env)) (arg_pairs, arg_metrics) = unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t)) diff --git a/src/Data.hs b/src/Data.hs index e7b3148..e86aaa6 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -11,6 +11,8 @@ module Data (module Data, (:~:)(Refl)) where import Data.Functor.Product +import Data.GADT.Compare +import Data.GADT.Show import Data.Some import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) @@ -73,10 +75,15 @@ data SNat n where SS :: SNat n -> SNat (S n) deriving instance Show (SNat n) -instance TestEquality SNat where - testEquality SZ SZ = Just Refl - testEquality (SS n) (SS n') | Just Refl <- testEquality n n' = Just Refl - testEquality _ _ = Nothing +instance GCompare SNat where + gcompare SZ SZ = GEQ + gcompare SZ _ = GLT + gcompare _ SZ = GGT + gcompare (SS n) (SS n') = gorderingLift1 (gcompare n n') + +instance TestEquality SNat where testEquality = geq +instance GEq SNat where geq = defaultGeq +instance GShow SNat where gshowsPrec = defaultGshowsPrec fromSNat :: SNat n -> Int fromSNat SZ = 0 @@ -90,10 +97,6 @@ reSNat :: Nat -> Some SNat reSNat Z = Some SZ reSNat (S n) | Some n' <- reSNat n = Some (SS n') -fromNat :: Nat -> Int -fromNat Z = 0 -fromNat (S m) = succ (fromNat m) - class KnownNat n where knownNat :: SNat n instance KnownNat Z where knownNat = SZ instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat @@ -155,6 +158,18 @@ vecInit (x :< xs@(_ :< _)) = x :< vecInit xs unsafeCoerceRefl :: a :~: b unsafeCoerceRefl = unsafeCoerce Refl +gorderingLift1 :: GOrdering a a' -> GOrdering (f a) (f a') +gorderingLift1 GLT = GLT +gorderingLift1 GGT = GGT +gorderingLift1 GEQ = GEQ + +gorderingLift2 :: GOrdering a a' -> GOrdering b b' -> GOrdering (f a b) (f a' b') +gorderingLift2 GLT _ = GLT +gorderingLift2 GGT _ = GGT +gorderingLift2 GEQ GLT = GLT +gorderingLift2 GEQ GGT = GGT +gorderingLift2 GEQ GEQ = GEQ + data Bag t = BNone | BOne t | BTwo !(Bag t) !(Bag t) | BMany [Bag t] | BList [t] deriving (Show, Functor, Foldable, Traversable) -- cgit v1.2.3-70-g09d2