{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

-- | Combinators for writing 'NExpr' terms: mostly thin wrappers around the
-- 'NExpr' / 'NFun' constructors, using named variables ('Var').
module Language
  ( fromNamed
  , NExpr
  , Ex
  , module Language
  , module AST.Types
  , module Data
  , Lookup
  ) where

import Array
import AST
import AST.Types
import CHAD.Types
import Data
import Language.AST


-- | Binder syntax: @v :-> e@ pairs a variable with the expression it scopes
-- over. The constructor is @infixr 0@, so @v1 :-> v2 :-> e@ groups as
-- @v1 :-> (v2 :-> e)@.
data a :-> b = a :-> b
  deriving (Show)
infixr 0 :->

-- | Finish a function definition. The body is typed in the environment built
-- up by the enclosing 'lambda's.
body :: NExpr env t -> NFun env env t
body e = NBody e

-- | Add one named parameter in front of a function. The @forall@ lists the
-- parameter type @a@ first, so callers can fix it with a type application
-- (see the example on '(.$)').
lambda
  :: forall a name env env' t.
     Var name a
  -> NFun ('(name, a) : env) env' t
  -> NFun env env' t
lambda param rest = NLam param rest

-- | Instantiate a closed function (an 'NFun' over the empty environment)
-- with an argument list built using '(.$)'. The result is a plain expression
-- produced by 'inlineNFun'.
inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t
inline fun args = inlineNFun fun args

-- | Build the argument list for 'inline'.
--
-- Each use conses its right operand onto the front of the list. 'lambda'
-- pushes parameters onto the front of the environment in the same way. As a
-- result, arguments are written in the same left-to-right order as the
-- parameters:
--
-- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y
-- > in inline fun (SNil .$ 16 .$ 26)
(.$) :: SList f list -> f a -> SList f (a : list)
(.$) xs x = SCons x xs
-- This is the Haskell default for operators without a fixity declaration. It
-- is spelled out so that the left-nesting relied on above is explicit.
infixl 9 .$

-- | @let_ v rhs e@ binds @v@ to @rhs@ (typed in @env@) within @e@.
let_
  :: forall a t env name.
     Var name a
  -> NExpr env a
  -> NExpr ('(name, a) : env) t
  -> NExpr env t
let_ var rhs scope = NELet var rhs scope

-- | Construct a pair.
pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
pair x y = NEPair x y

-- | First projection of a pair.
fst_ :: NExpr env (TPair a b) -> NExpr env a
fst_ p = NEFst p

-- | Second projection of a pair.
snd_ :: NExpr env (TPair a b) -> NExpr env b
snd_ p = NESnd p

-- | The value of type 'TNil'.
nil :: NExpr env TNil
nil = NENil

-- | Left injection. The type of the right side comes from 'KnownTy'.
inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b)
inl x = NEInl knownTy x

-- | Right injection. The type of the left side comes from 'KnownTy'.
inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b)
inr y = NEInr knownTy y

-- | Eliminate a 'TEither'. The first binder handles the left value and the
-- second handles the right value. Each branch sees its own variable on top of
-- @env@.
case_
  :: NExpr env (TEither a b)
  -> (Var name1 a :-> NExpr ('(name1, a) : env) c)
  -> (Var name2 b :-> NExpr ('(name2, b) : env) c)
  -> NExpr env c
case_ scrut (lv :-> onLeft) (rv :-> onRight) =
  NECase scrut lv onLeft rv onRight

-- | Embed a constant array of scalars.
--
-- 'scalRepIsShow' yields a 'Dict' that is matched before calling
-- 'NEConstArr'. Presumably it supplies a 'Show' instance for @ScalRep t@ that
-- the constructor needs.
constArr_
  :: forall t n env.
     (KnownNat n, KnownScalTy t)
  => Array n (ScalRep t)
  -> NExpr env (TArr n (TScal t))
constArr_ arr
  | Dict <- scalRepIsShow sty = NEConstArr knownNat sty arr
  where
    sty = knownScalTy

-- | One-dimensional build. @build1 len (i :-> e)@ produces an array whose
-- single extent is @len@, with @e@ computing the element at index @i@.
build1
  :: NExpr env TIx
  -> (Var name TIx :-> NExpr ('(name, TIx) : env) t)
  -> NExpr env (TArr (S Z) t)
build1 len (i :-> e) =
  -- 'NEBuild' binds the whole index tuple as #idx, and @i@ receives its only
  -- component. 'NEDrop' (SS SZ) hides #idx from @e@, whose type mentions
  -- just @i@ on top of @env@.
  NEBuild (SS SZ) (pair nil len) #idx $
    let_ i (snd_ #idx) $
      NEDrop (SS SZ) e

-- | Two-dimensional build. In @build2 len1 len2 (i :-> j :-> e)@, @i@ is
-- bound to the first index component and @j@ to the second.
build2
  :: NExpr env TIx
  -> NExpr env TIx
  -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t)
  -> NExpr env (TArr (S (S Z)) t)
build2 len1 len2 (i :-> j :-> e) =
  NEBuild (SS (SS SZ)) (pair (pair nil len1) len2) #idx $
    let_ i (snd_ (fst_ #idx)) $
      -- Hide the just-bound @i@ first. Presumably this keeps #idx referring
      -- to the build index even if the caller named @i@ "idx".
      let_ j (NEDrop SZ (snd_ #idx)) $
        -- Hide #idx (third from the top) from the caller's body.
        NEDrop (SS (SS SZ)) e

-- | General build over a rank-@n@ index tuple. The single variable is bound
-- to the whole index tuple.
build
  :: SNat n
  -> NExpr env (Tup (Replicate n TIx))
  -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t)
  -> NExpr env (TArr n t)
build rank sh (ix :-> e) = NEBuild rank sh ix e

-- | Elementwise map. @map_ (x :-> e) arr@ evaluates @e@ with @x@ bound to
-- each element of @arr@. The result has the same shape as @arr@.
map_
  :: forall n a b env name.
     (KnownNat n, KnownTy a)
  => (Var name a :-> NExpr ('(name, a) : env) b)
  -> NExpr env (TArr n a)
  -> NExpr env (TArr n b)
map_ (x :-> e) arr =
  -- Presumably 'styKnown' provides a 'KnownTy' instance for the rank-n index
  -- tuple type used by the build below.
  case styKnown (tTup (sreplicate (knownNat @n) tIx)) of
    Dict ->
      let_ #arg arr $
        build knownNat (shape #arg) $
          -- The two 'NEDrop's hide #i and then #arg, so @e@ sees only @x@
          -- on top of @env@.
          #i :-> let_ x (#arg ! #i) (NEDrop (SS SZ) (NEDrop (SS SZ) e))

-- | Fold along one dimension via 'NEFold1Inner', turning a rank @n+1@ array
-- into a rank @n@ array. The body combines two values of the element type,
-- bound to the two variables. The second argument is presumably the initial
-- value; confirm against 'NEFold1Inner'.
fold1i
  :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t)
  -> NExpr env t
  -> NExpr env (TArr (S n) t)
  -> NExpr env (TArr n t)
fold1i (x :-> y :-> combine) z arr = NEFold1Inner x y combine z arr

-- | Sum along one dimension ('NESum1Inner'), from rank @n+1@ to rank @n@.
sum1i
  :: ScalIsNumeric t ~ True
  => NExpr env (TArr (S n) (TScal t))
  -> NExpr env (TArr n (TScal t))
sum1i = NESum1Inner

-- | Wrap a value as a rank-0 array.
unit :: NExpr env t -> NExpr env (TArr Z t)
unit x = NEUnit x

-- | Add one dimension ('NEReplicate1Inner'), from rank @n@ to rank @n+1@.
-- The 'TIx' argument is presumably the size of the new dimension.
replicate1i
  :: ScalIsNumeric t ~ True
  => NExpr env TIx
  -> NExpr env (TArr n (TScal t))
  -> NExpr env (TArr (S n) (TScal t))
replicate1i = NEReplicate1Inner

-- | Maximum along one dimension ('NEMaximum1Inner'), from rank @n+1@ to
-- rank @n@.
maximum1i
  :: ScalIsNumeric t ~ True
  => NExpr env (TArr (S n) (TScal t))
  -> NExpr env (TArr n (TScal t))
maximum1i = NEMaximum1Inner

-- | Minimum along one dimension ('NEMinimum1Inner'), from rank @n+1@ to
-- rank @n@.
minimum1i
  :: ScalIsNumeric t ~ True
  => NExpr env (TArr (S n) (TScal t))
  -> NExpr env (TArr n (TScal t))
minimum1i = NEMinimum1Inner

-- | A scalar constant of any 'KnownScalTy' type. As in 'constArr_', the
-- 'Dict' from 'scalRepIsShow' is matched before building the node.
const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
const_ x
  | Dict <- scalRepIsShow sty = NEConst sty x
  where
    sty = knownScalTy

-- | Read the single element of a rank-0 array.
idx0 :: NExpr env (TArr Z t) -> NExpr env t
idx0 arr = NEIdx0 arr

-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
-- (.!) = NEIdx1
-- infixl 9 .!

-- | Index an array with a full index tuple (one 'TIx' per dimension).
(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
(!) arr ix = NEIdx arr ix
infixl 9 !
-- | The shape of an array, as a tuple with one 'TIx' per dimension.
shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
shape arr = NEShape arr

-- | Apply a primitive operator to its argument expression.
oper :: SOp a t -> NExpr env a -> NExpr env t
oper op arg = NEOp op arg

-- | Apply a primitive operator whose argument is a pair. The two components
-- are passed as separate expressions and paired here.
oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t
oper2 op lhs rhs = oper op (pair lhs rhs)

-- | An error term carrying the given message. The result type is supplied by
-- 'KnownTy'.
error_ :: KnownTy t => String -> NExpr env t
error_ = NEError knownTy

-- | A user-defined operation on two inputs, given by three closed functions
-- (each over an environment containing only its two binders):
--
-- * the plain implementation @a -> b -> t@;
-- * a function from @D1 a@ and @D1 b@ to a pair of @D1 t@ and a @tape@;
-- * a function from the @tape@ and a @D2 t@ to a @D2 b@.
--
-- The last two arguments are the actual inputs.
custom
  :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
  -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape))
  -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b))
  -> NExpr env a
  -> NExpr env b
  -> NExpr env t
custom (p1 :-> p2 :-> primal) (f1 :-> f2 :-> fwd) (r1 :-> r2 :-> rev) arg1 arg2 =
  NECustom p1 p2 primal f1 f2 fwd r1 r2 rev arg1 arg2

-- | Scalar equality via 'OEq'. This is restricted to numeric scalar types by
-- the 'ScalIsNumeric' constraint.
(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
x .== y = oper2 (OEq knownScalTy) x y
infix 4 .==

-- | Strict less-than via 'OLt'.
(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
x .< y = oper2 (OLt knownScalTy) x y
infix 4 .<

-- | Strict greater-than: '(.<)' with its operands swapped.
(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
x .> y = y .< x
infix 4 .>

-- | Less-than-or-equal via 'OLe'.
(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
x .<= y = oper2 (OLe knownScalTy) x y
infix 4 .<=

-- | Greater-than-or-equal: '(.<=)' with its operands swapped.
(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
x .>= y = y .<= x
infix 4 .>=

-- | Boolean negation via 'ONot'.
not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool)
not_ b = oper ONot b

-- | Boolean conjunction via 'OAnd'. Both operands are passed together as a
-- pair. NOTE(review): whether evaluation short-circuits is decided by the
-- evaluator, not here. The fixity matches Prelude's '&&'.
and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
and_ p q = oper2 OAnd p q
infixr 3 `and_`

-- | Boolean disjunction via 'OOr'. Both operands are passed together as a
-- pair. The fixity matches Prelude's '||'.
or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
or_ p q = oper2 OOr p q
infixr 2 `or_`

-- | The first alternative is the True case; the second is the False case.
-- Concretely, @if_ c t f@ converts @c@ with the 'OIf' operator and
-- eliminates the result with 'case_'. Each branch binds an unused variable
-- named @_@, which 'NEDrop' hides again, so @t@ and @f@ are typed in @env@.
if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
if_ cond whenTrue whenFalse =
  case_ (oper OIf cond)
    (#_ :-> NEDrop SZ whenTrue)
    (#_ :-> NEDrop SZ whenFalse)

-- | Convert a 'TF64' scalar to 'TI64' with the 'ORound64' primitive.
round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64)
round_ x = oper ORound64 x

-- | Convert a 'TI64' scalar to 'TF64' with the 'OToFl64' primitive.
toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64)
toFloat_ n = oper OToFl64 n

-- | Integral division via the 'OIDiv' primitive. It is declared @infixl 7@,
-- like Prelude's 'div'.
idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t)
idiv x y = oper (OIDiv knownScalTy) (pair x y)
infixl 7 `idiv`