{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE TypeOperators #-}

-- | Surface-syntax combinators for building 'SExpr' terms.
--
-- The binding forms ('lambda', 'let_', 'case_', 'build1', 'build',
-- 'fold1') take ordinary Haskell functions and turn them into syntactic
-- binders via the @mkLambda@ family from "Language.AST". The remaining
-- definitions are thin wrappers around 'SExpr' constructors or around
-- 'oper' applied to a primitive operator.
module Language (
  scopeCheck,
  SExpr,
  module Language,
) where

import AST
import Data
import Language.AST

-- * Functions

-- | Add one argument of type @a@ to the function produced by @f@.
--
-- The new argument is appended at the /end/ of the argument list
-- (see the result type @Append args '[a]@). The binder tag comes from
-- 'mkLambda', which is given @f@ both as its first argument and as the
-- body. NOTE(review): the first argument is presumably used for naming;
-- see "Language.AST".
lambda :: forall a args t. KnownTy a => (SExpr a -> SFun args t) -> SFun (Append args '[a]) t
lambda f = case mkLambda f f of
  Lambda tag (SFun args e) -> SFun (sappend args (tag `SCons` SNil)) e

-- | Wrap an expression as an 'SFun' with an empty argument list; use it as
-- the innermost body under a chain of 'lambda's.
body :: SExpr t -> SFun '[] t
body e = SFun SNil e

-- * Binding

-- | @let_ rhs f@ binds the value of @rhs@ to a variable and uses @f@
-- applied to that variable as the body.
let_ :: KnownTy a => SExpr a -> (SExpr a -> SExpr t) -> SExpr t
let_ rhs f = SELet rhs (mkLambda (rhs, f) f)

-- * Products

-- | Build a pair.
pair :: SExpr a -> SExpr b -> SExpr (TPair a b)
pair = SEPair

-- | First component of a pair.
fst_ :: SExpr (TPair a b) -> SExpr a
fst_ = SEFst

-- | Second component of a pair.
snd_ :: SExpr (TPair a b) -> SExpr b
snd_ = SESnd

-- | The value of type 'TNil'.
nil :: SExpr TNil
nil = SENil

-- * Sums

-- | Left injection. The type of the absent right component is supplied
-- explicitly as an 'STy'.
inl :: STy b -> SExpr a -> SExpr (TEither a b)
inl = SEInl

-- | Right injection. The type of the absent left component is supplied
-- explicitly as an 'STy'.
inr :: STy a -> SExpr b -> SExpr (TEither a b)
inr = SEInr

-- | Eliminate a sum: @f@ handles the left alternative, @g@ the right one.
case_ :: (KnownTy a, KnownTy b) => SExpr (TEither a b) -> (SExpr a -> SExpr c) -> (SExpr b -> SExpr c) -> SExpr c
case_ e f g = SECase e (mkLambda (e, f) f) (mkLambda (e, g) g)

-- * Arrays

-- | One-dimensional array construction. @f@ receives each index.
-- NOTE(review): @e@ is presumably the array length — confirm against
-- 'SEBuild1'.
build1 :: SExpr TIx -> (SExpr TIx -> SExpr t) -> SExpr (TArr (S Z) t)
build1 e f = SEBuild1 e (mkLambda (e, f) f)

-- | Rank-@n@ array construction. @f@ receives an index tuple of @n@
-- indices.
--
-- The binder type is passed explicitly to @mkLambda'@ as
-- @tTup (sreplicate n tIx)@ rather than taken from a 'KnownTy'
-- constraint. NOTE(review): @e@ is presumably the shape — confirm against
-- 'SEBuild'.
build :: SNat n -> SExpr (Tup (Replicate n TIx)) -> (SExpr (Tup (Replicate n TIx)) -> SExpr t) -> SExpr (TArr n t)
build n e f = SEBuild n e (mkLambda' (e, f) (tTup (sreplicate n tIx)) f)

-- | Reduce a rank-@(n+1)@ array to rank @n@ by combining elements with the
-- binary function @f@. Which dimension is reduced is determined by
-- 'SEFold1'.
fold1 :: KnownTy t => (SExpr t -> SExpr t -> SExpr t) -> SExpr (TArr (S n) t) -> SExpr (TArr n t)
fold1 f e = SEFold1 (mkLambda2 (f, e) f) e

-- | Wrap a value as a rank-0 array.
unit :: SExpr t -> SExpr (TArr Z t)
unit = SEUnit

-- | Extract the single element of a rank-0 array.
idx0 :: SExpr (TArr Z t) -> SExpr t
idx0 = SEIdx0

-- | Index the outermost dimension of an array, reducing its rank by one.
--
-- Left-associative at precedence 9 (the same as "Data.Array"'s @!@), so
-- @arr .! i .! j@ means @(arr .! i) .! j@.
infixl 9 .!
(.!) :: SExpr (TArr (S n) t) -> SExpr TIx -> SExpr (TArr n t)
(.!) = SEIdx1

-- | Index a rank-@n@ array with a full tuple of @n@ indices. The rank is
-- passed explicitly as an 'SNat', so this is normally used prefix:
-- @(!) n arr ix@.
(!) :: SNat n -> SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -> SExpr t
(!) = SEIdx

-- | The shape of an array, as a tuple of @n@ indices.
shape :: SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx))
shape = SEShape

-- * Scalars and primitive operations

-- | Scalar literal. The scalar type comes from the 'KnownScalTy'
-- constraint.
--
-- Matching on the 'Dict' returned by 'scalRepIsShow' brings a constraint
-- into scope that 'SEConst' needs. NOTE(review): presumably
-- @Show (ScalRep t)@.
const_ :: KnownScalTy t => ScalRep t -> SExpr (TScal t)
const_ x =
  let ty = knownScalTy
  in case scalRepIsShow ty of
       Dict -> SEConst ty x

-- | Apply a primitive operator to its (possibly tupled) argument.
oper :: SOp a t -> SExpr a -> SExpr t
oper = SEOp

-- | An error expression carrying message @s@; its type comes from the
-- 'KnownTy' constraint. NOTE(review): presumably aborts evaluation with
-- @s@ — see 'SEError'.
error_ :: KnownTy t => String -> SExpr t
error_ s = SEError knownTy s

-- ** Comparisons
--
-- All comparison operators are @infix 4@, like the Prelude's '==' and
-- '<'. Without these declarations they would default to @infixl 9@ and
-- bind tighter than arithmetic and '.!'. With them, @x + 1 .< y@ parses
-- as @(x + 1) .< y@ and @x .== arr .! i@ as @x .== (arr .! i)@.

-- | Equality on scalars.
infix 4 .==
(.==) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool)
a .== b = oper (OEq knownScalTy) (pair a b)

-- | Strict less-than on scalars.
infix 4 .<
(.<) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool)
a .< b = oper (OLt knownScalTy) (pair a b)

-- | Strict greater-than, expressed as '.<' with the operands swapped.
infix 4 .>
(.>) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool)
(.>) = flip (.<)

-- | Less-than-or-equal on scalars.
infix 4 .<=
(.<=) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool)
a .<= b = oper (OLe knownScalTy) (pair a b)

-- | Greater-than-or-equal, expressed as '.<=' with the operands swapped.
infix 4 .>=
(.>=) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool)
(.>=) = flip (.<=)

-- ** Booleans

-- | Boolean negation.
not_ :: SExpr (TScal TBool) -> SExpr (TScal TBool)
not_ = oper ONot

-- | Conditional. It is desugared to 'case_' on the result of the 'OIf'
-- operator, and both branches ignore the payload of their alternative.
--
-- NOTE(review): @a@ is the then-branch only if 'OIf' maps true to the left
-- injection — confirm against 'OIf'.
if_ :: SExpr (TScal TBool) -> SExpr t -> SExpr t -> SExpr t
if_ e a b = case_ (oper OIf e) (\_ -> a) (\_ -> b)