diff options
-rw-r--r-- | chad-fast.cabal | 7 | ||||
-rw-r--r-- | src/AST.hs | 40 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 38 | ||||
-rw-r--r-- | src/Data.hs | 15 | ||||
-rw-r--r-- | src/Example.hs | 82 | ||||
-rw-r--r-- | src/Language.hs | 104 | ||||
-rw-r--r-- | src/Language/AST.hs | 134 | ||||
-rw-r--r-- | src/Language/Tag.hs | 22 |
8 files changed, 363 insertions, 79 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 1bff84b..0c9170c 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -20,14 +20,17 @@ library -- Compile Data Example + Language + Language.AST + Language.Tag Lemmas - PreludeCu + -- PreludeCu Simplify other-modules: build-depends: base >= 4.19 && < 4.21, containers, - template-haskell, + -- template-haskell, transformers, hs-source-dirs: src @@ -18,6 +18,7 @@ module AST (module AST, module AST.Weaken) where import Data.Functor.Const import Data.Kind (Type) import Data.Int +import Data.Type.Equality import AST.Env import AST.Weaken @@ -46,6 +47,15 @@ data STy t where STAccum :: STy t -> STy (TAccum t) deriving instance Show (STy t) +instance TestEquality STy where + testEquality STNil STNil = Just Refl + testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl + testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl + testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl + testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl + testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl + testEquality _ _ = Nothing + data SScalTy t where STI32 :: SScalTy TI32 STI64 :: SScalTy TI64 @@ -54,6 +64,21 @@ data SScalTy t where STBool :: SScalTy TBool deriving instance Show (SScalTy t) +instance TestEquality SScalTy where + testEquality STI32 STI32 = Just Refl + testEquality STI64 STI64 = Just Refl + testEquality STF32 STF32 = Just Refl + testEquality STF64 STF64 = Just Refl + testEquality STBool STBool = Just Refl + testEquality _ _ = Nothing + +scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) +scalRepIsShow STI32 = Dict +scalRepIsShow STI64 = Dict +scalRepIsShow STF32 = Dict +scalRepIsShow STF64 = Dict +scalRepIsShow STBool = Dict + type TIx = TScal TI64 tIx :: STy TIx @@ -305,6 +330,21 @@ class KnownEnv env where knownEnv :: SList STy env instance KnownEnv '[] where knownEnv = SNil instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv +styKnown :: STy t -> Dict (KnownTy t) +styKnown STNil = Dict +styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict +styKnown (STScal t) | Dict <- sscaltyKnown t = Dict +styKnown (STAccum t) | Dict <- styKnown t = Dict + +sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) +sscaltyKnown STI32 = Dict +sscaltyKnown STI64 = Dict +sscaltyKnown STF32 = Dict +sscaltyKnown STF64 = Dict +sscaltyKnown STBool = Dict + ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) $ diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 5610d36..bf0d350 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -17,16 +17,7 @@ import AST.Count import Data -data Val f env where - VTop :: Val f '[] - VPush :: f t -> Val f env -> Val f (t : env) - -type SVal = Val (Const String) - -valprj :: Val f env -> Idx env t -> f t -valprj (VPush x _) IZ = x -valprj (VPush _ env) (IS i) = valprj env i -valprj VTop i = case i of {} +type SVal = SList (Const String) newtype M a = M { runM :: Int -> (a, Int) } deriving (Functor) @@ -51,15 +42,20 @@ genNameIfUsedIn' prefix ty idx ex genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = genNameIfUsedIn' "x" +valprj :: SList f env -> Idx env t -> f t +valprj (x `SCons` _) IZ = x +valprj (_ `SCons` env) (IS i) = valprj env i +valprj SNil i = case i of {} + ppExpr :: SList STy env -> Expr x env t -> String ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) "" where mkVal :: SList STy env -> M (SVal env) - mkVal SNil = return VTop + mkVal SNil = return SNil mkVal (SCons _ v) = do val <- mkVal v name <- genName - return (VPush (Const name) val) + return (Const name `SCons` val) ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS ppExpr' d val = \case @@ -94,9 +90,9 @@ ppExpr' d val = \case e' <- ppExpr' 0 val e let STEither t1 t2 = typeOf e name1 <- genNameIfUsedIn t1 IZ a - a' <- ppExpr' 0 (VPush (Const name1) val) a + a' <- ppExpr' 0 (Const name1 `SCons` val) a name2 <- genNameIfUsedIn t2 IZ b - b' <- ppExpr' 0 (VPush (Const name2) val) b + b' <- ppExpr' 0 (Const name2 `SCons` val) b return $ showParen (d > 0) $ showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a' . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }" @@ -104,21 +100,21 @@ ppExpr' d val = \case EBuild1 _ a b -> do a' <- ppExpr' 11 val a name <- genNameIfUsedIn (STScal STI64) IZ b - b' <- ppExpr' 0 (VPush (Const name) val) b + b' <- ppExpr' 0 (Const name `SCons` val) b return $ showParen (d > 10) $ showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")" EBuild _ n a b -> do a' <- ppExpr' 11 val a name <- genNameIfUsedIn (tTup (sreplicate n tIx)) IZ b - e' <- ppExpr' 0 (VPush (Const name) val) b + e' <- ppExpr' 0 (Const name `SCons` val) b return $ showParen (d > 10) $ showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")" EFold1 _ a b -> do name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a name2 <- genNameIfUsedIn (typeOf a) IZ a - a' <- ppExpr' 0 (VPush (Const name2) (VPush (Const name1) val)) a + a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a b' <- ppExpr' 11 val b return $ showParen (d > 10) $ showString ("fold1 (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a' @@ -142,13 +138,13 @@ ppExpr' d val = \case EIdx1 _ a b -> do a' <- ppExpr' 9 val a b' <- ppExpr' 9 val b - return $ showParen (d > 8) $ a' . showString " ! " . b' + return $ showParen (d > 8) $ a' . showString " .! " . b' EIdx _ _ a b -> do a' <- ppExpr' 9 val a b' <- ppExpr' 10 val b return $ showParen (d > 8) $ - a' . showString " !! " . b' + a' . showString " ! " . b' EShape _ e -> do e' <- ppExpr' 11 val e @@ -170,7 +166,7 @@ ppExpr' d val = \case EWith e1 e2 -> do e1' <- ppExpr' 11 val e1 name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 - e2' <- ppExpr' 0 (VPush (Const name) val) e2 + e2' <- ppExpr' 0 (Const name `SCons` val) e2 return $ showParen (d > 10) $ showString "with " . e1' . showString (" (\\" ++ name ++ " -> ") . e2' . showString ")" @@ -191,7 +187,7 @@ ppExprLet d val etop = do let occ = occCount IZ body name <- genNameIfUsedIn (typeOf rhs) IZ body rhs' <- ppExpr' 0 val' rhs - (binds, core) <- collect (VPush (Const name) val') body + (binds, core) <- collect (Const name `SCons` val') body return ((name, occ, rhs') : binds, core) collect val' e = ([],) <$> ppExpr' 0 val' e diff --git a/src/Data.hs b/src/Data.hs index eb6c033..840cb88 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -9,9 +9,15 @@ {-# LANGUAGE TypeOperators #-} module Data where +import Data.Type.Equality + import Lemmas (Append) +data Dict c where + Dict :: c => Dict c + + data SList f l where SNil :: SList f '[] SCons :: f a -> SList f l -> SList f (a : l) @@ -42,6 +48,11 @@ data SNat n where SS :: SNat n -> SNat (S n) deriving instance Show (SNat n) +instance TestEquality SNat where + testEquality SZ SZ = Just Refl + testEquality (SS n) (SS n') | Just Refl <- testEquality n n' = Just Refl + testEquality _ _ = Nothing + fromSNat :: SNat n -> Int fromSNat SZ = 0 fromSNat (SS n) = succ (fromSNat n) @@ -50,6 +61,10 @@ class KnownNat n where knownNat :: SNat n instance KnownNat Z where knownNat = SZ instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat +snatKnown :: SNat n -> Dict (KnownNat n) +snatKnown SZ = Dict +snatKnown (SS n) | Dict <- snatKnown n = Dict + data Vec n t where VNil :: Vec Z t (:<) :: t -> Vec n t -> Vec (S n) t diff --git a/src/Example.hs b/src/Example.hs index 424351c..6fd19cd 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -6,6 +6,7 @@ import AST import AST.Pretty import CHAD import Data +import Language import Simplify @@ -51,46 +52,24 @@ descr1 a b = DTop `DPush` (t, a) `DPush` (t, b) -- one "v2" (x1 * x5) -- one "v1" x5) ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex1 = - bin (OAdd STF32) - (bin (OMul STF32) - (EVar ext (STScal STF32) (IS IZ)) - (EVar ext (STScal STF32) IZ)) - (EVar ext (STScal STF32) (IS IZ)) +ex1 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ + x * y + x -- x y |- let z = x + y in z * (z + x) ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex2 = - ELet ext (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ)) - (EVar ext (STScal STF32) IZ)) $ - bin (OMul STF32) - (EVar ext (STScal STF32) IZ) - (bin (OAdd STF32) - (EVar ext (STScal STF32) IZ) - (EVar ext (STScal STF32) (IS (IS IZ)))) +ex2 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ + let_ (x + y) $ \z -> + z * (z + x) -- x y |- if x < y then 2 * x else 3 + x ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex3 = - ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ)) - (EVar ext (STScal STF32) IZ))) - (bin (OMul STF32) (EConst ext STF32 2.0) - (EVar ext (STScal STF32) (IS (IS IZ)))) - (bin (OAdd STF32) (EConst ext STF32 3.0) - (EVar ext (STScal STF32) (IS (IS IZ)))) +ex3 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ + if_ (x .< y) (2 * x) (3 * x) -- x y |- if x < y then 2 * x + y * y else 3 + x ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex4 = - ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ)) - (EVar ext (STScal STF32) IZ))) - (bin (OAdd STF32) - (bin (OMul STF32) (EConst ext STF32 2.0) - (EVar ext (STScal STF32) (IS (IS IZ)))) - (bin (OMul STF32) (EVar ext (STScal STF32) (IS IZ)) - (EVar ext (STScal STF32) (IS IZ)))) - (bin (OAdd STF32) (EConst ext STF32 3.0) - (EVar ext (STScal STF32) (IS (IS IZ)))) +ex4 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ + if_ (x .< y) (2 * x + y * y) (3 + x) senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)] senv5 = STScal STF32 `SCons` STEither (STScal STF32) (STScal STF32) `SCons` SNil @@ -101,13 +80,9 @@ descr5 a b = DTop `DPush` (STEither (STScal STF32) (STScal STF32), a) `DPush` (S -- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)} ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) -ex5 = - ECase ext (EVar ext (STEither (STScal STF32) (STScal STF32)) (IS IZ)) - (bin (OMul STF32) (EVar ext (STScal STF32) IZ) - (EVar ext (STScal STF32) (IS IZ))) - (bin (OMul STF32) (EVar ext (STScal STF32) IZ) - (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ)) - (EConst ext STF32 1.0))) +ex5 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ + case_ x (\a -> a * y) + (\b -> b * (y + 1)) senv6 :: SList STy [TScal TI64, TScal TF32] senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil @@ -119,13 +94,10 @@ descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge) -- b = build1 n (\i. let c = idx0 a in c * c) -- in idx0 (b ! 3) ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) -ex6 = - ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $ - ELet ext (EBuild1 ext (EVar ext tIx (IS IZ)) $ - ELet ext (EIdx0 ext (EVar ext (STArr SZ (STScal STF32)) (IS IZ))) $ - bin (OMul STF32) (EVar ext (STScal STF32) IZ) - (EVar ext (STScal STF32) IZ)) $ - (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3))) +ex6 = scopeCheck $ lambda $ \x -> lambda $ \n -> body $ + let_ (unit x) $ \a -> + let_ (build1 n (\_ -> let_ (idx0 a) $ \c -> c * c)) $ \b -> + idx0 (b .! 3) type R = TScal TF32 @@ -154,19 +126,17 @@ descr7 = -- x3 = fst p3 * x + snd p3 -- in x3 ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R -ex7 = +ex7 = scopeCheck $ lambda $ \pars123 -> lambda $ \input -> body $ let tR = STScal STF32 tpair = STPair tR tR - layer :: STy p -> Idx env p -> Idx env R -> Ex env R - layer parst@(STPair t (STPair (STScal STF32) (STScal STF32))) pars inp = - ELet ext (ESnd ext (EVar ext parst pars)) $ - ELet ext (EFst ext (EVar ext parst (IS pars))) $ - ELet ext (bin (OAdd STF32) (bin (OMul STF32) (EFst ext (EVar ext tpair (IS IZ))) - (EVar ext tR (IS (IS inp)))) - (ESnd ext (EVar ext tpair (IS IZ)))) $ - layer t (IS IZ) IZ - layer STNil _ inp = EVar ext tR inp + layer :: STy p -> SExpr p -> SExpr R -> SExpr R + layer (STPair t (STPair (STScal STF32) (STScal STF32))) pars inp | Dict <- styKnown t = + let_ (snd_ pars) $ \par -> + let_ (fst_ pars) $ \restpars -> + let_ (fst_ par * inp + snd_ par) $ \res -> + layer t restpars res + layer STNil _ inp = inp layer _ _ _ = error "Invalid layer inputs" - in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) (IS IZ) IZ + in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) pars123 input diff --git a/src/Language.hs b/src/Language.hs new file mode 100644 index 0000000..b76e07f --- /dev/null +++ b/src/Language.hs @@ -0,0 +1,104 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExplicitForAll #-} +{-# LANGUAGE TypeOperators #-} +module Language ( + scopeCheck, + SExpr, + module Language, +) where + +import AST +import Data +import Language.AST + + +lambda :: forall a args t. KnownTy a => (SExpr a -> SFun args t) -> SFun (Append args '[a]) t +lambda f = case mkLambda f f of + Lambda tag (SFun args e) -> + SFun (sappend args (tag `SCons` SNil)) e + +body :: SExpr t -> SFun '[] t +body e = SFun SNil e + + +let_ :: KnownTy a => SExpr a -> (SExpr a -> SExpr t) -> SExpr t +let_ rhs f = SELet rhs (mkLambda (rhs, f) f) + +pair :: SExpr a -> SExpr b -> SExpr (TPair a b) +pair = SEPair + +fst_ :: SExpr (TPair a b) -> SExpr a +fst_ = SEFst + +snd_ :: SExpr (TPair a b) -> SExpr b +snd_ = SESnd + +nil :: SExpr TNil +nil = SENil + +inl :: STy b -> SExpr a -> SExpr (TEither a b) +inl = SEInl + +inr :: STy a -> SExpr b -> SExpr (TEither a b) +inr = SEInr + +case_ :: (KnownTy a, KnownTy b) + => SExpr (TEither a b) -> (SExpr a -> SExpr c) -> (SExpr b -> SExpr c) -> SExpr c +case_ e f g = SECase e (mkLambda (e, f) f) (mkLambda (e, g) g) + +build1 :: SExpr TIx -> (SExpr TIx -> SExpr t) -> SExpr (TArr (S Z) t) +build1 e f = SEBuild1 e (mkLambda (e, f) f) + +build :: SNat n -> SExpr (Tup (Replicate n TIx)) -> (SExpr (Tup (Replicate n TIx)) -> SExpr t) -> SExpr (TArr n t) +build n e f = SEBuild n e (mkLambda' (e, f) (tTup (sreplicate n tIx)) f) + +fold1 :: KnownTy t => (SExpr t -> SExpr t -> SExpr t) -> SExpr (TArr (S n) t) -> SExpr (TArr n t) +fold1 f e = SEFold1 (mkLambda2 (f, e) f) e + +unit :: SExpr t -> SExpr (TArr Z t) +unit = SEUnit + +const_ :: KnownScalTy t => ScalRep t -> SExpr (TScal t) +const_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> SEConst ty x + +idx0 :: SExpr (TArr Z t) -> SExpr t +idx0 = SEIdx0 + +(.!) :: SExpr (TArr (S n) t) -> SExpr TIx -> SExpr (TArr n t) +(.!) = SEIdx1 + +(!) :: SNat n -> SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -> SExpr t +(!) = SEIdx + +shape :: SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) +shape = SEShape + +oper :: SOp a t -> SExpr a -> SExpr t +oper = SEOp + +error_ :: KnownTy t => String -> SExpr t +error_ s = SEError knownTy s + +(.==) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool) +a .== b = oper (OEq knownScalTy) (pair a b) + +(.<) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool) +a .< b = oper (OLt knownScalTy) (pair a b) + +(.>) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool) +(.>) = flip (.<) + +(.<=) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool) +a .<= b = oper (OLe knownScalTy) (pair a b) + +(.>=) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool) +(.>=) = flip (.<=) + +not_ :: SExpr (TScal TBool) -> SExpr (TScal TBool) +not_ = oper ONot + +if_ :: SExpr (TScal TBool) -> SExpr t -> SExpr t -> SExpr t +if_ e a b = case_ (oper OIf e) (\_ -> a) (\_ -> b) diff --git a/src/Language/AST.hs b/src/Language/AST.hs new file mode 100644 index 0000000..1c53c8a --- /dev/null +++ b/src/Language/AST.hs @@ -0,0 +1,134 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module Language.AST where + +import AST +import Data +import Data.Type.Equality +import Language.Tag + + +data SExpr t where + -- lambda calculus + SEVar :: Tag t -> SExpr t + SELet :: SExpr a -> Lambda a (SExpr t) -> SExpr t + + -- base types + SEPair :: SExpr a -> SExpr b -> SExpr (TPair a b) + SEFst :: SExpr (TPair a b) -> SExpr a + SESnd :: SExpr (TPair a b) -> SExpr b + SENil :: SExpr TNil + SEInl :: STy b -> SExpr a -> SExpr (TEither a b) + SEInr :: STy a -> SExpr b -> SExpr (TEither a b) + SECase :: SExpr (TEither a b) -> Lambda a (SExpr c) -> Lambda b (SExpr c) -> SExpr c + + -- array operations + SEBuild1 :: SExpr TIx -> Lambda TIx (SExpr t) -> SExpr (TArr (S Z) t) + SEBuild :: SNat n -> SExpr (Tup (Replicate n TIx)) -> Lambda (Tup (Replicate n TIx)) (SExpr t) -> SExpr (TArr n t) + SEFold1 :: Lambda t (Lambda t (SExpr t)) -> SExpr (TArr (S n) t) -> SExpr (TArr n t) + SEUnit :: SExpr t -> SExpr (TArr Z t) + + -- expression operations + SEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> SExpr (TScal t) + SEIdx0 :: SExpr (TArr Z t) -> SExpr t + SEIdx1 :: SExpr (TArr (S n) t) -> SExpr TIx -> SExpr (TArr n t) + SEIdx :: SNat n -> SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -> SExpr t + SEShape :: SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) + SEOp :: SOp a t -> SExpr a -> SExpr t + + -- partiality + SEError :: STy a -> String -> SExpr a +deriving instance Show (SExpr t) + +data Lambda a b = Lambda (Tag a) b + deriving (Show) + +mkLambda :: KnownTy a => handle -> (SExpr a -> f t) -> Lambda a (f t) +mkLambda handle f = mkLambda' handle knownTy f + +mkLambda' :: handle -> STy a -> (SExpr a -> f t) -> Lambda a (f t) +mkLambda' handle ty f = + let tag = genTag handle ty + in Lambda tag (f (SEVar tag)) + +mkLambda2 :: (KnownTy a, KnownTy b) + => handle -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t)) +mkLambda2 handle f = mkLambda2' handle knownTy knownTy f + +mkLambda2' :: handle -> STy a -> STy b -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t)) +mkLambda2' handle ty1 ty2 f = + let tag2 = genTag handle ty2 + lam2 = Lambda tag2 (f (SEVar tag1) (SEVar tag2)) + tag1 = genTag lam2 ty1 + in Lambda tag1 lam2 + +instance (t ~ TScal st, KnownScalTy st, Num (ScalRep st)) => Num (SExpr t) where + a + b = SEOp (OAdd knownScalTy) (SEPair a b) + a * b = SEOp (OMul knownScalTy) (SEPair a b) + negate e = SEOp (ONeg knownScalTy) e + abs = error "abs undefined for SExpr" + signum = error "signum undefined for SExpr" + fromInteger = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> SEConst ty . fromInteger + +data SFun args t = SFun (SList Tag args) (SExpr t) + +scopeCheck :: SFun env t -> Ex env t +scopeCheck (SFun args e) = scopeCheckExpr args e + +scopeCheckExpr :: forall env t. SList Tag env -> SExpr t -> Ex env t +scopeCheckExpr val = \case + SEVar tag@(Tag ty _) + | Just idx <- find tag val -> EVar ext ty idx + | otherwise -> error "Variable out of scope in conversion from surface \ + \expression to De Bruijn expression" + SELet a b -> ELet ext (go a) (lambda val b) + + SEPair a b -> EPair ext (go a) (go b) + SEFst e -> EFst ext (go e) + SESnd e -> ESnd ext (go e) + SENil -> ENil ext + SEInl t e -> EInl ext t (go e) + SEInr t e -> EInr ext t (go e) + SECase e a b -> ECase ext (go e) (lambda val a) (lambda val b) + + SEBuild1 a b -> EBuild1 ext (go a) (lambda val b) + SEBuild n a b -> EBuild ext n (go a) (lambda val b) + SEFold1 a b -> EFold1 ext (lambda2 val a) (go b) + SEUnit e -> EUnit ext (go e) + + SEConst t x -> EConst ext t x + SEIdx0 e -> EIdx0 ext (go e) + SEIdx1 a b -> EIdx1 ext (go a) (go b) + SEIdx n a b -> EIdx ext n (go a) (go b) + SEShape e -> EShape ext (go e) + SEOp op e -> EOp ext op (go e) + + SEError t s -> EError t s + where + go :: SExpr t' -> Ex env t' + go = scopeCheckExpr val + + find :: Tag t' -> SList Tag env' -> Maybe (Idx env' t') + find _ SNil = Nothing + find tag@(Tag ty i) (Tag ty' i' `SCons` val') + | i == i' + , Just Refl <- testEquality ty ty' + = Just IZ + | otherwise + = IS <$> find tag val' + + lambda :: SList Tag env' -> Lambda a (SExpr b) -> Ex (a : env') b + lambda val' (Lambda tag e) = scopeCheckExpr (tag `SCons` val') e + + lambda2 :: SList Tag env' -> Lambda a (Lambda b (SExpr c)) -> Ex (a : b : env') c + lambda2 val' (Lambda tag (Lambda tag' e)) = scopeCheckExpr (tag `SCons` tag' `SCons` val') e diff --git a/src/Language/Tag.hs b/src/Language/Tag.hs new file mode 100644 index 0000000..9356073 --- /dev/null +++ b/src/Language/Tag.hs @@ -0,0 +1,22 @@ +{-# LANGUAGE BangPatterns #-} +module Language.Tag ( + Tag(..), genTag, +) where + +import Data.IORef +import System.IO.Unsafe + +import AST + + +data Tag t = Tag (STy t) Int + deriving (Show) + +{-# NOINLINE tagCounter #-} +tagCounter :: IORef Int +tagCounter = unsafePerformIO $ newIORef 1 + +{-# NOINLINE genTag #-} +genTag :: handle -> STy t -> Tag t +genTag !_ ty = + unsafePerformIO $ Tag ty <$> atomicModifyIORef' tagCounter (\i -> (succ i, i)) |