summaryrefslogtreecommitdiff
path: root/src/Language/AST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Language/AST.hs')
-rw-r--r--src/Language/AST.hs134
1 files changed, 134 insertions, 0 deletions
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
new file mode 100644
index 0000000..1c53c8a
--- /dev/null
+++ b/src/Language/AST.hs
@@ -0,0 +1,134 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module Language.AST where
+
+import AST
+import Data
+import Data.Type.Equality
+import Language.Tag
+
+
+data SExpr t where
+ -- lambda calculus
+ SEVar :: Tag t -> SExpr t
+ SELet :: SExpr a -> Lambda a (SExpr t) -> SExpr t
+
+ -- base types
+ SEPair :: SExpr a -> SExpr b -> SExpr (TPair a b)
+ SEFst :: SExpr (TPair a b) -> SExpr a
+ SESnd :: SExpr (TPair a b) -> SExpr b
+ SENil :: SExpr TNil
+ SEInl :: STy b -> SExpr a -> SExpr (TEither a b)
+ SEInr :: STy a -> SExpr b -> SExpr (TEither a b)
+ SECase :: SExpr (TEither a b) -> Lambda a (SExpr c) -> Lambda b (SExpr c) -> SExpr c
+
+ -- array operations
+ SEBuild1 :: SExpr TIx -> Lambda TIx (SExpr t) -> SExpr (TArr (S Z) t)
+ SEBuild :: SNat n -> SExpr (Tup (Replicate n TIx)) -> Lambda (Tup (Replicate n TIx)) (SExpr t) -> SExpr (TArr n t)
+ SEFold1 :: Lambda t (Lambda t (SExpr t)) -> SExpr (TArr (S n) t) -> SExpr (TArr n t)
+ SEUnit :: SExpr t -> SExpr (TArr Z t)
+
+ -- expression operations
+ SEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> SExpr (TScal t)
+ SEIdx0 :: SExpr (TArr Z t) -> SExpr t
+ SEIdx1 :: SExpr (TArr (S n) t) -> SExpr TIx -> SExpr (TArr n t)
+ SEIdx :: SNat n -> SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -> SExpr t
+ SEShape :: SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx))
+ SEOp :: SOp a t -> SExpr a -> SExpr t
+
+ -- partiality
+ SEError :: STy a -> String -> SExpr a
+deriving instance Show (SExpr t)
+
+data Lambda a b = Lambda (Tag a) b
+ deriving (Show)
+
+mkLambda :: KnownTy a => handle -> (SExpr a -> f t) -> Lambda a (f t)
+mkLambda handle f = mkLambda' handle knownTy f
+
+mkLambda' :: handle -> STy a -> (SExpr a -> f t) -> Lambda a (f t)
+mkLambda' handle ty f =
+ let tag = genTag handle ty
+ in Lambda tag (f (SEVar tag))
+
+mkLambda2 :: (KnownTy a, KnownTy b)
+ => handle -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t))
+mkLambda2 handle f = mkLambda2' handle knownTy knownTy f
+
+mkLambda2' :: handle -> STy a -> STy b -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t))
+mkLambda2' handle ty1 ty2 f =
+ let tag2 = genTag handle ty2
+ lam2 = Lambda tag2 (f (SEVar tag1) (SEVar tag2))
+ tag1 = genTag lam2 ty1
+ in Lambda tag1 lam2
+
+instance (t ~ TScal st, KnownScalTy st, Num (ScalRep st)) => Num (SExpr t) where
+ a + b = SEOp (OAdd knownScalTy) (SEPair a b)
+ a * b = SEOp (OMul knownScalTy) (SEPair a b)
+ negate e = SEOp (ONeg knownScalTy) e
+ abs = error "abs undefined for SExpr"
+ signum = error "signum undefined for SExpr"
+ fromInteger =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> SEConst ty . fromInteger
+
+data SFun args t = SFun (SList Tag args) (SExpr t)
+
+scopeCheck :: SFun env t -> Ex env t
+scopeCheck (SFun args e) = scopeCheckExpr args e
+
+scopeCheckExpr :: forall env t. SList Tag env -> SExpr t -> Ex env t
+scopeCheckExpr val = \case
+ SEVar tag@(Tag ty _)
+ | Just idx <- find tag val -> EVar ext ty idx
+ | otherwise -> error "Variable out of scope in conversion from surface \
+ \expression to De Bruijn expression"
+ SELet a b -> ELet ext (go a) (lambda val b)
+
+ SEPair a b -> EPair ext (go a) (go b)
+ SEFst e -> EFst ext (go e)
+ SESnd e -> ESnd ext (go e)
+ SENil -> ENil ext
+ SEInl t e -> EInl ext t (go e)
+ SEInr t e -> EInr ext t (go e)
+ SECase e a b -> ECase ext (go e) (lambda val a) (lambda val b)
+
+ SEBuild1 a b -> EBuild1 ext (go a) (lambda val b)
+ SEBuild n a b -> EBuild ext n (go a) (lambda val b)
+ SEFold1 a b -> EFold1 ext (lambda2 val a) (go b)
+ SEUnit e -> EUnit ext (go e)
+
+ SEConst t x -> EConst ext t x
+ SEIdx0 e -> EIdx0 ext (go e)
+ SEIdx1 a b -> EIdx1 ext (go a) (go b)
+ SEIdx n a b -> EIdx ext n (go a) (go b)
+ SEShape e -> EShape ext (go e)
+ SEOp op e -> EOp ext op (go e)
+
+ SEError t s -> EError t s
+ where
+ go :: SExpr t' -> Ex env t'
+ go = scopeCheckExpr val
+
+ find :: Tag t' -> SList Tag env' -> Maybe (Idx env' t')
+ find _ SNil = Nothing
+ find tag@(Tag ty i) (Tag ty' i' `SCons` val')
+ | i == i'
+ , Just Refl <- testEquality ty ty'
+ = Just IZ
+ | otherwise
+ = IS <$> find tag val'
+
+ lambda :: SList Tag env' -> Lambda a (SExpr b) -> Ex (a : env') b
+ lambda val' (Lambda tag e) = scopeCheckExpr (tag `SCons` val') e
+
+ lambda2 :: SList Tag env' -> Lambda a (Lambda b (SExpr c)) -> Ex (a : b : env') c
+ lambda2 val' (Lambda tag (Lambda tag' e)) = scopeCheckExpr (tag `SCons` tag' `SCons` val') e