summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal7
-rw-r--r--src/AST.hs40
-rw-r--r--src/AST/Pretty.hs38
-rw-r--r--src/Data.hs15
-rw-r--r--src/Example.hs82
-rw-r--r--src/Language.hs104
-rw-r--r--src/Language/AST.hs134
-rw-r--r--src/Language/Tag.hs22
8 files changed, 363 insertions, 79 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 1bff84b..0c9170c 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -20,14 +20,17 @@ library
-- Compile
Data
Example
+ Language
+ Language.AST
+ Language.Tag
Lemmas
- PreludeCu
+ -- PreludeCu
Simplify
other-modules:
build-depends:
base >= 4.19 && < 4.21,
containers,
- template-haskell,
+ -- template-haskell,
transformers,
hs-source-dirs:
src
diff --git a/src/AST.hs b/src/AST.hs
index f389467..785e34a 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -18,6 +18,7 @@ module AST (module AST, module AST.Weaken) where
import Data.Functor.Const
import Data.Kind (Type)
import Data.Int
+import Data.Type.Equality
import AST.Env
import AST.Weaken
@@ -46,6 +47,15 @@ data STy t where
STAccum :: STy t -> STy (TAccum t)
deriving instance Show (STy t)
+instance TestEquality STy where
+ testEquality STNil STNil = Just Refl
+ testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality _ _ = Nothing
+
data SScalTy t where
STI32 :: SScalTy TI32
STI64 :: SScalTy TI64
@@ -54,6 +64,21 @@ data SScalTy t where
STBool :: SScalTy TBool
deriving instance Show (SScalTy t)
+instance TestEquality SScalTy where
+ testEquality STI32 STI32 = Just Refl
+ testEquality STI64 STI64 = Just Refl
+ testEquality STF32 STF32 = Just Refl
+ testEquality STF64 STF64 = Just Refl
+ testEquality STBool STBool = Just Refl
+ testEquality _ _ = Nothing
+
+scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
+scalRepIsShow STI32 = Dict
+scalRepIsShow STI64 = Dict
+scalRepIsShow STF32 = Dict
+scalRepIsShow STF64 = Dict
+scalRepIsShow STBool = Dict
+
type TIx = TScal TI64
tIx :: STy TIx
@@ -305,6 +330,21 @@ class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
+styKnown :: STy t -> Dict (KnownTy t)
+styKnown STNil = Dict
+styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
+styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
+styKnown (STAccum t) | Dict <- styKnown t = Dict
+
+sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
+sscaltyKnown STI32 = Dict
+sscaltyKnown STI64 = Dict
+sscaltyKnown STF32 = Dict
+sscaltyKnown STF64 = Dict
+sscaltyKnown STBool = Dict
+
ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
ebuildUp1 n sh size f =
EBuild ext (SS n) (EPair ext sh size) $
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 5610d36..bf0d350 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -17,16 +17,7 @@ import AST.Count
import Data
-data Val f env where
- VTop :: Val f '[]
- VPush :: f t -> Val f env -> Val f (t : env)
-
-type SVal = Val (Const String)
-
-valprj :: Val f env -> Idx env t -> f t
-valprj (VPush x _) IZ = x
-valprj (VPush _ env) (IS i) = valprj env i
-valprj VTop i = case i of {}
+type SVal = SList (Const String)
newtype M a = M { runM :: Int -> (a, Int) }
deriving (Functor)
@@ -51,15 +42,20 @@ genNameIfUsedIn' prefix ty idx ex
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn = genNameIfUsedIn' "x"
+valprj :: SList f env -> Idx env t -> f t
+valprj (x `SCons` _) IZ = x
+valprj (_ `SCons` env) (IS i) = valprj env i
+valprj SNil i = case i of {}
+
ppExpr :: SList STy env -> Expr x env t -> String
ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
where
mkVal :: SList STy env -> M (SVal env)
- mkVal SNil = return VTop
+ mkVal SNil = return SNil
mkVal (SCons _ v) = do
val <- mkVal v
name <- genName
- return (VPush (Const name) val)
+ return (Const name `SCons` val)
ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS
ppExpr' d val = \case
@@ -94,9 +90,9 @@ ppExpr' d val = \case
e' <- ppExpr' 0 val e
let STEither t1 t2 = typeOf e
name1 <- genNameIfUsedIn t1 IZ a
- a' <- ppExpr' 0 (VPush (Const name1) val) a
+ a' <- ppExpr' 0 (Const name1 `SCons` val) a
name2 <- genNameIfUsedIn t2 IZ b
- b' <- ppExpr' 0 (VPush (Const name2) val) b
+ b' <- ppExpr' 0 (Const name2 `SCons` val) b
return $ showParen (d > 0) $
showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
. showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }"
@@ -104,21 +100,21 @@ ppExpr' d val = \case
EBuild1 _ a b -> do
a' <- ppExpr' 11 val a
name <- genNameIfUsedIn (STScal STI64) IZ b
- b' <- ppExpr' 0 (VPush (Const name) val) b
+ b' <- ppExpr' 0 (Const name `SCons` val) b
return $ showParen (d > 10) $
showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")"
EBuild _ n a b -> do
a' <- ppExpr' 11 val a
name <- genNameIfUsedIn (tTup (sreplicate n tIx)) IZ b
- e' <- ppExpr' 0 (VPush (Const name) val) b
+ e' <- ppExpr' 0 (Const name `SCons` val) b
return $ showParen (d > 10) $
showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")"
EFold1 _ a b -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
name2 <- genNameIfUsedIn (typeOf a) IZ a
- a' <- ppExpr' 0 (VPush (Const name2) (VPush (Const name1) val)) a
+ a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
b' <- ppExpr' 11 val b
return $ showParen (d > 10) $
showString ("fold1 (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a'
@@ -142,13 +138,13 @@ ppExpr' d val = \case
EIdx1 _ a b -> do
a' <- ppExpr' 9 val a
b' <- ppExpr' 9 val b
- return $ showParen (d > 8) $ a' . showString " ! " . b'
+ return $ showParen (d > 8) $ a' . showString " .! " . b'
EIdx _ _ a b -> do
a' <- ppExpr' 9 val a
b' <- ppExpr' 10 val b
return $ showParen (d > 8) $
- a' . showString " !! " . b'
+ a' . showString " ! " . b'
EShape _ e -> do
e' <- ppExpr' 11 val e
@@ -170,7 +166,7 @@ ppExpr' d val = \case
EWith e1 e2 -> do
e1' <- ppExpr' 11 val e1
name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
- e2' <- ppExpr' 0 (VPush (Const name) val) e2
+ e2' <- ppExpr' 0 (Const name `SCons` val) e2
return $ showParen (d > 10) $
showString "with " . e1' . showString (" (\\" ++ name ++ " -> ")
. e2' . showString ")"
@@ -191,7 +187,7 @@ ppExprLet d val etop = do
let occ = occCount IZ body
name <- genNameIfUsedIn (typeOf rhs) IZ body
rhs' <- ppExpr' 0 val' rhs
- (binds, core) <- collect (VPush (Const name) val') body
+ (binds, core) <- collect (Const name `SCons` val') body
return ((name, occ, rhs') : binds, core)
collect val' e = ([],) <$> ppExpr' 0 val' e
diff --git a/src/Data.hs b/src/Data.hs
index eb6c033..840cb88 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -9,9 +9,15 @@
{-# LANGUAGE TypeOperators #-}
module Data where
+import Data.Type.Equality
+
import Lemmas (Append)
+data Dict c where
+ Dict :: c => Dict c
+
+
data SList f l where
SNil :: SList f '[]
SCons :: f a -> SList f l -> SList f (a : l)
@@ -42,6 +48,11 @@ data SNat n where
SS :: SNat n -> SNat (S n)
deriving instance Show (SNat n)
+instance TestEquality SNat where
+ testEquality SZ SZ = Just Refl
+ testEquality (SS n) (SS n') | Just Refl <- testEquality n n' = Just Refl
+ testEquality _ _ = Nothing
+
fromSNat :: SNat n -> Int
fromSNat SZ = 0
fromSNat (SS n) = succ (fromSNat n)
@@ -50,6 +61,10 @@ class KnownNat n where knownNat :: SNat n
instance KnownNat Z where knownNat = SZ
instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat
+snatKnown :: SNat n -> Dict (KnownNat n)
+snatKnown SZ = Dict
+snatKnown (SS n) | Dict <- snatKnown n = Dict
+
data Vec n t where
VNil :: Vec Z t
(:<) :: t -> Vec n t -> Vec (S n) t
diff --git a/src/Example.hs b/src/Example.hs
index 424351c..6fd19cd 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -6,6 +6,7 @@ import AST
import AST.Pretty
import CHAD
import Data
+import Language
import Simplify
@@ -51,46 +52,24 @@ descr1 a b = DTop `DPush` (t, a) `DPush` (t, b)
-- one "v2" (x1 * x5)
-- one "v1" x5)
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex1 =
- bin (OAdd STF32)
- (bin (OMul STF32)
- (EVar ext (STScal STF32) (IS IZ))
- (EVar ext (STScal STF32) IZ))
- (EVar ext (STScal STF32) (IS IZ))
+ex1 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
+ x * y + x
-- x y |- let z = x + y in z * (z + x)
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex2 =
- ELet ext (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ))
- (EVar ext (STScal STF32) IZ)) $
- bin (OMul STF32)
- (EVar ext (STScal STF32) IZ)
- (bin (OAdd STF32)
- (EVar ext (STScal STF32) IZ)
- (EVar ext (STScal STF32) (IS (IS IZ))))
+ex2 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
+ let_ (x + y) $ \z ->
+ z * (z + x)
-- x y |- if x < y then 2 * x else 3 + x
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex3 =
- ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ))
- (EVar ext (STScal STF32) IZ)))
- (bin (OMul STF32) (EConst ext STF32 2.0)
- (EVar ext (STScal STF32) (IS (IS IZ))))
- (bin (OAdd STF32) (EConst ext STF32 3.0)
- (EVar ext (STScal STF32) (IS (IS IZ))))
+ex3 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
+ if_ (x .< y) (2 * x) (3 * x)
-- x y |- if x < y then 2 * x + y * y else 3 + x
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex4 =
- ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ))
- (EVar ext (STScal STF32) IZ)))
- (bin (OAdd STF32)
- (bin (OMul STF32) (EConst ext STF32 2.0)
- (EVar ext (STScal STF32) (IS (IS IZ))))
- (bin (OMul STF32) (EVar ext (STScal STF32) (IS IZ))
- (EVar ext (STScal STF32) (IS IZ))))
- (bin (OAdd STF32) (EConst ext STF32 3.0)
- (EVar ext (STScal STF32) (IS (IS IZ))))
+ex4 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
+ if_ (x .< y) (2 * x + y * y) (3 + x)
senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
senv5 = STScal STF32 `SCons` STEither (STScal STF32) (STScal STF32) `SCons` SNil
@@ -101,13 +80,9 @@ descr5 a b = DTop `DPush` (STEither (STScal STF32) (STScal STF32), a) `DPush` (S
-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
-ex5 =
- ECase ext (EVar ext (STEither (STScal STF32) (STScal STF32)) (IS IZ))
- (bin (OMul STF32) (EVar ext (STScal STF32) IZ)
- (EVar ext (STScal STF32) (IS IZ)))
- (bin (OMul STF32) (EVar ext (STScal STF32) IZ)
- (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ))
- (EConst ext STF32 1.0)))
+ex5 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
+ case_ x (\a -> a * y)
+ (\b -> b * (y + 1))
senv6 :: SList STy [TScal TI64, TScal TF32]
senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil
@@ -119,13 +94,10 @@ descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge)
-- b = build1 n (\i. let c = idx0 a in c * c)
-- in idx0 (b ! 3)
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
-ex6 =
- ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $
- ELet ext (EBuild1 ext (EVar ext tIx (IS IZ)) $
- ELet ext (EIdx0 ext (EVar ext (STArr SZ (STScal STF32)) (IS IZ))) $
- bin (OMul STF32) (EVar ext (STScal STF32) IZ)
- (EVar ext (STScal STF32) IZ)) $
- (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3)))
+ex6 = scopeCheck $ lambda $ \x -> lambda $ \n -> body $
+ let_ (unit x) $ \a ->
+ let_ (build1 n (\_ -> let_ (idx0 a) $ \c -> c * c)) $ \b ->
+ idx0 (b .! 3)
type R = TScal TF32
@@ -154,19 +126,17 @@ descr7 =
-- x3 = fst p3 * x + snd p3
-- in x3
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
-ex7 =
+ex7 = scopeCheck $ lambda $ \pars123 -> lambda $ \input -> body $
let tR = STScal STF32
tpair = STPair tR tR
- layer :: STy p -> Idx env p -> Idx env R -> Ex env R
- layer parst@(STPair t (STPair (STScal STF32) (STScal STF32))) pars inp =
- ELet ext (ESnd ext (EVar ext parst pars)) $
- ELet ext (EFst ext (EVar ext parst (IS pars))) $
- ELet ext (bin (OAdd STF32) (bin (OMul STF32) (EFst ext (EVar ext tpair (IS IZ)))
- (EVar ext tR (IS (IS inp))))
- (ESnd ext (EVar ext tpair (IS IZ)))) $
- layer t (IS IZ) IZ
- layer STNil _ inp = EVar ext tR inp
+ layer :: STy p -> SExpr p -> SExpr R -> SExpr R
+ layer (STPair t (STPair (STScal STF32) (STScal STF32))) pars inp | Dict <- styKnown t =
+ let_ (snd_ pars) $ \par ->
+ let_ (fst_ pars) $ \restpars ->
+ let_ (fst_ par * inp + snd_ par) $ \res ->
+ layer t restpars res
+ layer STNil _ inp = inp
layer _ _ _ = error "Invalid layer inputs"
- in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) (IS IZ) IZ
+ in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) pars123 input
diff --git a/src/Language.hs b/src/Language.hs
new file mode 100644
index 0000000..b76e07f
--- /dev/null
+++ b/src/Language.hs
@@ -0,0 +1,104 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExplicitForAll #-}
+{-# LANGUAGE TypeOperators #-}
+module Language (
+ scopeCheck,
+ SExpr,
+ module Language,
+) where
+
+import AST
+import Data
+import Language.AST
+
+
+lambda :: forall a args t. KnownTy a => (SExpr a -> SFun args t) -> SFun (Append args '[a]) t
+lambda f = case mkLambda f f of
+ Lambda tag (SFun args e) ->
+ SFun (sappend args (tag `SCons` SNil)) e
+
+body :: SExpr t -> SFun '[] t
+body e = SFun SNil e
+
+
+let_ :: KnownTy a => SExpr a -> (SExpr a -> SExpr t) -> SExpr t
+let_ rhs f = SELet rhs (mkLambda (rhs, f) f)
+
+pair :: SExpr a -> SExpr b -> SExpr (TPair a b)
+pair = SEPair
+
+fst_ :: SExpr (TPair a b) -> SExpr a
+fst_ = SEFst
+
+snd_ :: SExpr (TPair a b) -> SExpr b
+snd_ = SESnd
+
+nil :: SExpr TNil
+nil = SENil
+
+inl :: STy b -> SExpr a -> SExpr (TEither a b)
+inl = SEInl
+
+inr :: STy a -> SExpr b -> SExpr (TEither a b)
+inr = SEInr
+
+case_ :: (KnownTy a, KnownTy b)
+ => SExpr (TEither a b) -> (SExpr a -> SExpr c) -> (SExpr b -> SExpr c) -> SExpr c
+case_ e f g = SECase e (mkLambda (e, f) f) (mkLambda (e, g) g)
+
+build1 :: SExpr TIx -> (SExpr TIx -> SExpr t) -> SExpr (TArr (S Z) t)
+build1 e f = SEBuild1 e (mkLambda (e, f) f)
+
+build :: SNat n -> SExpr (Tup (Replicate n TIx)) -> (SExpr (Tup (Replicate n TIx)) -> SExpr t) -> SExpr (TArr n t)
+build n e f = SEBuild n e (mkLambda' (e, f) (tTup (sreplicate n tIx)) f)
+
+fold1 :: KnownTy t => (SExpr t -> SExpr t -> SExpr t) -> SExpr (TArr (S n) t) -> SExpr (TArr n t)
+fold1 f e = SEFold1 (mkLambda2 (f, e) f) e
+
+unit :: SExpr t -> SExpr (TArr Z t)
+unit = SEUnit
+
+const_ :: KnownScalTy t => ScalRep t -> SExpr (TScal t)
+const_ x =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> SEConst ty x
+
+idx0 :: SExpr (TArr Z t) -> SExpr t
+idx0 = SEIdx0
+
+(.!) :: SExpr (TArr (S n) t) -> SExpr TIx -> SExpr (TArr n t)
+(.!) = SEIdx1
+
+(!) :: SNat n -> SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -> SExpr t
+(!) = SEIdx
+
+shape :: SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx))
+shape = SEShape
+
+oper :: SOp a t -> SExpr a -> SExpr t
+oper = SEOp
+
+error_ :: KnownTy t => String -> SExpr t
+error_ s = SEError knownTy s
+
+(.==) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool)
+a .== b = oper (OEq knownScalTy) (pair a b)
+
+(.<) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool)
+a .< b = oper (OLt knownScalTy) (pair a b)
+
+(.>) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool)
+(.>) = flip (.<)
+
+(.<=) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool)
+a .<= b = oper (OLe knownScalTy) (pair a b)
+
+(.>=) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool)
+(.>=) = flip (.<=)
+
+not_ :: SExpr (TScal TBool) -> SExpr (TScal TBool)
+not_ = oper ONot
+
+if_ :: SExpr (TScal TBool) -> SExpr t -> SExpr t -> SExpr t
+if_ e a b = case_ (oper OIf e) (\_ -> a) (\_ -> b)
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
new file mode 100644
index 0000000..1c53c8a
--- /dev/null
+++ b/src/Language/AST.hs
@@ -0,0 +1,134 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module Language.AST where
+
+import AST
+import Data
+import Data.Type.Equality
+import Language.Tag
+
+
+data SExpr t where
+ -- lambda calculus
+ SEVar :: Tag t -> SExpr t
+ SELet :: SExpr a -> Lambda a (SExpr t) -> SExpr t
+
+ -- base types
+ SEPair :: SExpr a -> SExpr b -> SExpr (TPair a b)
+ SEFst :: SExpr (TPair a b) -> SExpr a
+ SESnd :: SExpr (TPair a b) -> SExpr b
+ SENil :: SExpr TNil
+ SEInl :: STy b -> SExpr a -> SExpr (TEither a b)
+ SEInr :: STy a -> SExpr b -> SExpr (TEither a b)
+ SECase :: SExpr (TEither a b) -> Lambda a (SExpr c) -> Lambda b (SExpr c) -> SExpr c
+
+ -- array operations
+ SEBuild1 :: SExpr TIx -> Lambda TIx (SExpr t) -> SExpr (TArr (S Z) t)
+ SEBuild :: SNat n -> SExpr (Tup (Replicate n TIx)) -> Lambda (Tup (Replicate n TIx)) (SExpr t) -> SExpr (TArr n t)
+ SEFold1 :: Lambda t (Lambda t (SExpr t)) -> SExpr (TArr (S n) t) -> SExpr (TArr n t)
+ SEUnit :: SExpr t -> SExpr (TArr Z t)
+
+ -- expression operations
+ SEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> SExpr (TScal t)
+ SEIdx0 :: SExpr (TArr Z t) -> SExpr t
+ SEIdx1 :: SExpr (TArr (S n) t) -> SExpr TIx -> SExpr (TArr n t)
+ SEIdx :: SNat n -> SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -> SExpr t
+ SEShape :: SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx))
+ SEOp :: SOp a t -> SExpr a -> SExpr t
+
+ -- partiality
+ SEError :: STy a -> String -> SExpr a
+deriving instance Show (SExpr t)
+
+data Lambda a b = Lambda (Tag a) b
+ deriving (Show)
+
+mkLambda :: KnownTy a => handle -> (SExpr a -> f t) -> Lambda a (f t)
+mkLambda handle f = mkLambda' handle knownTy f
+
+mkLambda' :: handle -> STy a -> (SExpr a -> f t) -> Lambda a (f t)
+mkLambda' handle ty f =
+ let tag = genTag handle ty
+ in Lambda tag (f (SEVar tag))
+
+mkLambda2 :: (KnownTy a, KnownTy b)
+ => handle -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t))
+mkLambda2 handle f = mkLambda2' handle knownTy knownTy f
+
+mkLambda2' :: handle -> STy a -> STy b -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t))
+mkLambda2' handle ty1 ty2 f =
+ let tag2 = genTag handle ty2
+ lam2 = Lambda tag2 (f (SEVar tag1) (SEVar tag2))
+ tag1 = genTag lam2 ty1
+ in Lambda tag1 lam2
+
+instance (t ~ TScal st, KnownScalTy st, Num (ScalRep st)) => Num (SExpr t) where
+ a + b = SEOp (OAdd knownScalTy) (SEPair a b)
+ a * b = SEOp (OMul knownScalTy) (SEPair a b)
+ negate e = SEOp (ONeg knownScalTy) e
+ abs = error "abs undefined for SExpr"
+ signum = error "signum undefined for SExpr"
+ fromInteger =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> SEConst ty . fromInteger
+
+data SFun args t = SFun (SList Tag args) (SExpr t)
+
+scopeCheck :: SFun env t -> Ex env t
+scopeCheck (SFun args e) = scopeCheckExpr args e
+
+scopeCheckExpr :: forall env t. SList Tag env -> SExpr t -> Ex env t
+scopeCheckExpr val = \case
+ SEVar tag@(Tag ty _)
+ | Just idx <- find tag val -> EVar ext ty idx
+ | otherwise -> error "Variable out of scope in conversion from surface \
+ \expression to De Bruijn expression"
+ SELet a b -> ELet ext (go a) (lambda val b)
+
+ SEPair a b -> EPair ext (go a) (go b)
+ SEFst e -> EFst ext (go e)
+ SESnd e -> ESnd ext (go e)
+ SENil -> ENil ext
+ SEInl t e -> EInl ext t (go e)
+ SEInr t e -> EInr ext t (go e)
+ SECase e a b -> ECase ext (go e) (lambda val a) (lambda val b)
+
+ SEBuild1 a b -> EBuild1 ext (go a) (lambda val b)
+ SEBuild n a b -> EBuild ext n (go a) (lambda val b)
+ SEFold1 a b -> EFold1 ext (lambda2 val a) (go b)
+ SEUnit e -> EUnit ext (go e)
+
+ SEConst t x -> EConst ext t x
+ SEIdx0 e -> EIdx0 ext (go e)
+ SEIdx1 a b -> EIdx1 ext (go a) (go b)
+ SEIdx n a b -> EIdx ext n (go a) (go b)
+ SEShape e -> EShape ext (go e)
+ SEOp op e -> EOp ext op (go e)
+
+ SEError t s -> EError t s
+ where
+ go :: SExpr t' -> Ex env t'
+ go = scopeCheckExpr val
+
+ find :: Tag t' -> SList Tag env' -> Maybe (Idx env' t')
+ find _ SNil = Nothing
+ find tag@(Tag ty i) (Tag ty' i' `SCons` val')
+ | i == i'
+ , Just Refl <- testEquality ty ty'
+ = Just IZ
+ | otherwise
+ = IS <$> find tag val'
+
+ lambda :: SList Tag env' -> Lambda a (SExpr b) -> Ex (a : env') b
+ lambda val' (Lambda tag e) = scopeCheckExpr (tag `SCons` val') e
+
+ lambda2 :: SList Tag env' -> Lambda a (Lambda b (SExpr c)) -> Ex (a : b : env') c
+ lambda2 val' (Lambda tag (Lambda tag' e)) = scopeCheckExpr (tag `SCons` tag' `SCons` val') e
diff --git a/src/Language/Tag.hs b/src/Language/Tag.hs
new file mode 100644
index 0000000..9356073
--- /dev/null
+++ b/src/Language/Tag.hs
@@ -0,0 +1,22 @@
+{-# LANGUAGE BangPatterns #-}
+module Language.Tag (
+ Tag(..), genTag,
+) where
+
+import Data.IORef
+import System.IO.Unsafe
+
+import AST
+
+
+data Tag t = Tag (STy t) Int
+ deriving (Show)
+
+{-# NOINLINE tagCounter #-}
+tagCounter :: IORef Int
+tagCounter = unsafePerformIO $ newIORef 1
+
+{-# NOINLINE genTag #-}
+genTag :: handle -> STy t -> Tag t
+genTag !_ ty =
+ unsafePerformIO $ Tag ty <$> atomicModifyIORef' tagCounter (\i -> (succ i, i))