summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs1
-rw-r--r--src/AST/Pretty.hs41
-rw-r--r--src/AST/Types.hs116
-rw-r--r--src/Compile.hs10
-rw-r--r--src/Data.hs31
-rw-r--r--src/Simplify.hs32
6 files changed, 116 insertions, 115 deletions
diff --git a/src/AST.hs b/src/AST.hs
index fb5a45e..b8d23b4 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -89,7 +89,6 @@ data Expr x env t where
-- accumulation effect on monoids
EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t))
- -- TODO: let this contain a OneHotTerm that is shared with EOneHot for uniformity in Simplify
EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 4f637f2..fb5e138 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -7,7 +7,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
-module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppTy, PrettyX(..)) where
+module AST.Pretty (pprintExpr, ppExpr, ppSTy, PrettyX(..)) where
import Control.Monad (ap)
import Data.List (intersperse, intercalate)
@@ -159,13 +159,14 @@ ppExpr' d val expr = case expr of
a' <- ppExpr' 11 val a
name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
e' <- ppExpr' 0 (Const name `SCons` val) b
+ let primName = ppString ("build" ++ intSubscript (fromSNat n))
return $ ppParen (d > 0) $
group $ flatAlt
(hang 2 $
- annotate AHighlight (ppString "build") <> ppX expr <+> a'
+ annotate AHighlight primName <> ppX expr <+> a'
<+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
<> hardline <> e')
- (ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e'])
+ (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e'])
EFold1Inner _ cm a b c -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
@@ -354,28 +355,22 @@ operator OIDiv{} = (Infix, "`div`")
operator OMod{} = (Infix, "`mod`")
ppSTy :: Int -> STy t -> String
-ppSTy d ty = ppTy d (unSTy ty)
+ppSTy d ty = render $ ppSTy' d ty
ppSTy' :: Int -> STy t -> Doc q
-ppSTy' d ty = ppTy' d (unSTy ty)
-
-ppTy :: Int -> Ty -> String
-ppTy d ty = render $ ppTy' d ty
-
-ppTy' :: Int -> Ty -> Doc q
-ppTy' _ TNil = ppString "1"
-ppTy' d (TPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b
-ppTy' d (TEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b
-ppTy' d (TMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t
-ppTy' d (TArr n t) = ppParen (d > 10) $
- ppString "Arr " <> ppString (show (fromNat n)) <> ppString " " <> ppTy' 11 t
-ppTy' _ (TScal sty) = ppString $ case sty of
- TI32 -> "i32"
- TI64 -> "i64"
- TF32 -> "f32"
- TF64 -> "f64"
- TBool -> "bool"
-ppTy' d (TAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t
+ppSTy' _ STNil = ppString "1"
+ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b
+ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b
+ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t
+ppSTy' d (STArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t
+ppSTy' _ (STScal sty) = ppString $ case sty of
+ STI32 -> "i32"
+ STI64 -> "i64"
+ STF32 -> "f32"
+ STF64 -> "f64"
+ STBool -> "bool"
+ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSTy' 11 t
ppString :: String -> Doc x
ppString = fromString
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index 217b2f5..b20fc2d 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -1,34 +1,34 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeData #-}
module AST.Types where
import Data.Int (Int32, Int64)
+import Data.GADT.Compare
import Data.GADT.Show
import Data.Kind (Type)
-import Data.Some
import Data.Type.Equality
import Data
-data Ty
+type data Ty
= TNil
| TPair Ty Ty
| TEither Ty Ty
| TMaybe Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
- | TAccum Ty -- ^ the accumulator contains D2 of this type
- deriving (Show, Eq, Ord)
+ | TAccum Ty -- ^ contained type must be a monoid type
-data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
- deriving (Show, Eq, Ord)
+type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
type STy :: Ty -> Type
data STy t where
@@ -41,22 +41,25 @@ data STy t where
STAccum :: STy t -> STy (TAccum t)
deriving instance Show (STy t)
-instance TestEquality STy where
- testEquality STNil STNil = Just Refl
- testEquality STNil _ = Nothing
- testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
- testEquality STPair{} _ = Nothing
- testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
- testEquality STEither{} _ = Nothing
- testEquality (STMaybe a) (STMaybe a') | Just Refl <- testEquality a a' = Just Refl
- testEquality STMaybe{} _ = Nothing
- testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
- testEquality STArr{} _ = Nothing
- testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl
- testEquality STScal{} _ = Nothing
- testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl
- testEquality STAccum{} _ = Nothing
-
+instance GCompare STy where
+ gcompare = \cases
+ STNil STNil -> GEQ
+ STNil _ -> GLT ; _ STNil -> GGT
+ (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STPair{} _ -> GLT ; _ STPair{} -> GGT
+ (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STEither{} _ -> GLT ; _ STEither{} -> GGT
+ (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a')
+ STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT
+ (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
+ STArr{} _ -> GLT ; _ STArr{} -> GGT
+ (STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
+ STScal{} _ -> GLT ; _ STScal{} -> GGT
+ (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
+ -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+
+instance TestEquality STy where testEquality = geq
+instance GEq STy where geq = defaultGeq
instance GShow STy where gshowsPrec = defaultGshowsPrec
data SScalTy t where
@@ -67,14 +70,21 @@ data SScalTy t where
STBool :: SScalTy TBool
deriving instance Show (SScalTy t)
-instance TestEquality SScalTy where
- testEquality STI32 STI32 = Just Refl
- testEquality STI64 STI64 = Just Refl
- testEquality STF32 STF32 = Just Refl
- testEquality STF64 STF64 = Just Refl
- testEquality STBool STBool = Just Refl
- testEquality _ _ = Nothing
-
+instance GCompare SScalTy where
+ gcompare = \cases
+ STI32 STI32 -> GEQ
+ STI32 _ -> GLT ; _ STI32 -> GGT
+ STI64 STI64 -> GEQ
+ STI64 _ -> GLT ; _ STI64 -> GGT
+ STF32 STF32 -> GEQ
+ STF32 _ -> GLT ; _ STF32 -> GGT
+ STF64 STF64 -> GEQ
+ STF64 _ -> GLT ; _ STF64 -> GGT
+ STBool STBool -> GEQ
+ -- STBool _ -> GLT ; _ STBool -> GGT
+
+instance TestEquality SScalTy where testEquality = geq
+instance GEq SScalTy where geq = defaultGeq
instance GShow SScalTy where gshowsPrec = defaultGshowsPrec
scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
@@ -89,50 +99,6 @@ type TIx = TScal TI64
tIx :: STy TIx
tIx = STScal STI64
-unSTy :: STy t -> Ty
-unSTy = \case
- STNil -> TNil
- STPair a b -> TPair (unSTy a) (unSTy b)
- STEither a b -> TEither (unSTy a) (unSTy b)
- STMaybe t -> TMaybe (unSTy t)
- STArr n t -> TArr (unSNat n) (unSTy t)
- STScal t -> TScal (unSScalTy t)
- STAccum t -> TAccum (unSTy t)
-
-unSEnv :: SList STy env -> [Ty]
-unSEnv SNil = []
-unSEnv (SCons t l) = unSTy t : unSEnv l
-
-unSScalTy :: SScalTy t -> ScalTy
-unSScalTy = \case
- STI32 -> TI32
- STI64 -> TI64
- STF32 -> TF32
- STF64 -> TF64
- STBool -> TBool
-
-reSTy :: Ty -> Some STy
-reSTy = \case
- TNil -> Some STNil
- TPair a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STPair a' b'
- TEither a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STEither a' b'
- TMaybe t | Some t' <- reSTy t -> Some $ STMaybe t'
- TArr n t | Some n' <- reSNat n, Some t' <- reSTy t -> Some $ STArr n' t'
- TScal t | Some t' <- reSScalTy t -> Some $ STScal t'
- TAccum t | Some t' <- reSTy t -> Some $ STAccum t'
-
-reSEnv :: [Ty] -> Some (SList STy)
-reSEnv [] = Some SNil
-reSEnv (t : l) | Some t' <- reSTy t, Some env <- reSEnv l = Some (SCons t' env)
-
-reSScalTy :: ScalTy -> Some SScalTy
-reSScalTy = \case
- TI32 -> Some STI32
- TI64 -> Some STI64
- TF32 -> Some STF32
- TF64 -> Some STF64
- TBool -> Some STBool
-
type family ScalRep t where
ScalRep TI32 = Int32
ScalRep TI64 = Int64
@@ -161,7 +127,7 @@ type family ScalIsIntegral t where
ScalIsIntegral TF64 = False
ScalIsIntegral TBool = False
--- | Returns true for arrays /and/ accumulators;
+-- | Returns true for arrays /and/ accumulators.
hasArrays :: STy t' -> Bool
hasArrays STNil = False
hasArrays (STPair a b) = hasArrays a || hasArrays b
diff --git a/src/Compile.hs b/src/Compile.hs
index e3eb207..e2d004a 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -282,11 +282,11 @@ genStructs ty = do
tell (BList (genStruct name ty))
-genAllStructs :: Foldable t => t Ty -> [StructDecl]
-genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\t -> case reSTy t of Some t' -> genStructs t') tys)) mempty
+genAllStructs :: Foldable t => t (Some STy) -> [StructDecl]
+genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty
data CompState = CompState
- { csStructs :: Set Ty
+ { csStructs :: Set (Some STy)
, csTopLevelDecls :: Bag String
, csStmts :: Bag Stmt
, csNextId :: Int }
@@ -329,7 +329,7 @@ scope m = do
emitStruct :: STy t -> CompM String
emitStruct ty = CompM $ do
- modify $ \s -> s { csStructs = Set.insert (unSTy ty) (csStructs s) }
+ modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) }
return (genStructName ty)
emitTLD :: String -> CompM ()
@@ -348,7 +348,7 @@ compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets)
compileToString codeID env expr =
let args = nameEnv env
(res, s) = runCompM (compile' args expr)
- structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))
+ structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env))
(arg_pairs, arg_metrics) =
unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t))
diff --git a/src/Data.hs b/src/Data.hs
index e7b3148..e86aaa6 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -11,6 +11,8 @@
module Data (module Data, (:~:)(Refl)) where
import Data.Functor.Product
+import Data.GADT.Compare
+import Data.GADT.Show
import Data.Some
import Data.Type.Equality
import Unsafe.Coerce (unsafeCoerce)
@@ -73,10 +75,15 @@ data SNat n where
SS :: SNat n -> SNat (S n)
deriving instance Show (SNat n)
-instance TestEquality SNat where
- testEquality SZ SZ = Just Refl
- testEquality (SS n) (SS n') | Just Refl <- testEquality n n' = Just Refl
- testEquality _ _ = Nothing
+instance GCompare SNat where
+ gcompare SZ SZ = GEQ
+ gcompare SZ _ = GLT
+ gcompare _ SZ = GGT
+ gcompare (SS n) (SS n') = gorderingLift1 (gcompare n n')
+
+instance TestEquality SNat where testEquality = geq
+instance GEq SNat where geq = defaultGeq
+instance GShow SNat where gshowsPrec = defaultGshowsPrec
fromSNat :: SNat n -> Int
fromSNat SZ = 0
@@ -90,10 +97,6 @@ reSNat :: Nat -> Some SNat
reSNat Z = Some SZ
reSNat (S n) | Some n' <- reSNat n = Some (SS n')
-fromNat :: Nat -> Int
-fromNat Z = 0
-fromNat (S m) = succ (fromNat m)
-
class KnownNat n where knownNat :: SNat n
instance KnownNat Z where knownNat = SZ
instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat
@@ -155,6 +158,18 @@ vecInit (x :< xs@(_ :< _)) = x :< vecInit xs
unsafeCoerceRefl :: a :~: b
unsafeCoerceRefl = unsafeCoerce Refl
+gorderingLift1 :: GOrdering a a' -> GOrdering (f a) (f a')
+gorderingLift1 GLT = GLT
+gorderingLift1 GGT = GGT
+gorderingLift1 GEQ = GEQ
+
+gorderingLift2 :: GOrdering a a' -> GOrdering b b' -> GOrdering (f a b) (f a' b')
+gorderingLift2 GLT _ = GLT
+gorderingLift2 GGT _ = GGT
+gorderingLift2 GEQ GLT = GLT
+gorderingLift2 GEQ GGT = GGT
+gorderingLift2 GEQ GEQ = GEQ
+
data Bag t = BNone | BOne t | BTwo !(Bag t) !(Bag t) | BMany [Bag t] | BList [t]
deriving (Show, Functor, Foldable, Traversable)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index e0ab37b..ea3bb95 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -129,11 +129,36 @@ simplify' = \case
acted $ simplify' $
ECase ext e1 (ESnd ext e2) (ESnd ext e3)
- -- TODO: array indexing (index of build, index of fold)
+ -- TODO: more array indexing
+ EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
+ EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1
- -- TODO: beta rules for maybe
+ -- TODO: more constant folding
+ EOp _ OIf (EConst _ STBool True) -> (Any True, EInl ext STNil (ENil ext))
+ EOp _ OIf (EConst _ STBool False) -> (Any True, EInr ext STNil (ENil ext))
- -- TODO: constant folding for operations
+ -- inline cheap array constructors
+ ELet _ (EReplicate1Inner _ e1 e2) e3 ->
+ acted $ simplify' $
+ ELet ext (EPair ext e1 e2) $
+ let v = EVar ext (STPair tIx (typeOf e2)) IZ
+ in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3
+ -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is
+ -- -- cheap, which it can't be because (!) is not cheap if you do AD after.
+ -- -- Should do proper SoA representation.
+ -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 ->
+ -- acted $ simplify' $
+ -- ELet ext e1 $
+ -- subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3
+
+ -- eta rule for unit
+ e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) ->
+ case e of
+ ENil _ -> (Any False, e)
+ _ -> (Any True, ENil ext)
+
+ EBuild _ SZ _ e ->
+ acted $ simplify' $ EUnit ext (substInline (ENil ext) e)
-- monoid rules
EAccum _ t p e1 e2 acc -> do
@@ -222,6 +247,7 @@ cheapExpr = \case
EConst{} -> True
EFst _ e -> cheapExpr e
ESnd _ e -> cheapExpr e
+ EUnit _ e -> cheapExpr e
_ -> False
-- | This can be made more precise by tracking (and not counting) adds on