aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Language.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Language.hs')
-rw-r--r--src/CHAD/Language.hs266
1 files changed, 266 insertions, 0 deletions
diff --git a/src/CHAD/Language.hs b/src/CHAD/Language.hs
new file mode 100644
index 0000000..6dc91a5
--- /dev/null
+++ b/src/CHAD/Language.hs
@@ -0,0 +1,266 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExplicitForAll #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
+module CHAD.Language (
+ fromNamed,
+ NExpr,
+ Ex,
+ module CHAD.Language,
+ module CHAD.AST.Types,
+ Lookup,
+) where
+
+import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol)
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Sparse.Types
+import CHAD.AST.Types
+import CHAD.Data
+import CHAD.Drev.Types
+import CHAD.Language.AST
+
+
+data a :-> b = a :-> b
+ deriving (Show)
+infixr 0 :->
+
+
+body :: NExpr env t -> NFun env env t
+body = NBody
+
+lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
+lambda = NLam
+
+inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t
+inline = inlineNFun
+
+-- To be used to construct the argument list for 'inline'.
+--
+-- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y
+-- > in inline fun (SNil .$ 16 .$ 26)
+(.$) :: SList f list -> f a -> SList f (a : list)
+(.$) = flip SCons
+
+
+let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t
+let_ = NELet
+
+pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
+pair = NEPair
+
+fst_ :: NExpr env (TPair a b) -> NExpr env a
+fst_ = NEFst
+
+snd_ :: NExpr env (TPair a b) -> NExpr env b
+snd_ = NESnd
+
+nil :: NExpr env TNil
+nil = NENil
+
+inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b)
+inl = NEInl knownTy
+
+inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b)
+inr = NEInr knownTy
+
+case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c
+case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2
+
+nothing :: KnownTy a => NExpr env (TMaybe a)
+nothing = NENothing knownTy
+
+just :: NExpr env a -> NExpr env (TMaybe a)
+just = NEJust
+
+maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b
+maybe_ a (v :-> b) c = NEMaybe a v b c
+
+constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
+constArr_ x =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConstArr knownNat ty x
+
+build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t)
+build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b))
+
+build2 :: NExpr env TIx -> NExpr env TIx
+ -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t)
+ -> NExpr env (TArr (S (S Z)) t)
+build2 a1 a2 (v1 :-> v2 :-> b) =
+ NEBuild (SS (SS SZ))
+ (pair (pair nil a1) a2)
+ #idx
+ (let_ v1 (snd_ (fst_ #idx)) $
+ let_ v2 (NEDrop SZ (snd_ #idx)) $
+ NEDrop (SS (SS SZ)) b)
+
+build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t)
+build n a (v :-> b) = NEBuild n a v b
+
+map_ :: forall n a b env name. (KnownNat n, KnownTy a)
+ => (Var name a :-> NExpr ('(name, a) : env) b)
+ -> NExpr env (TArr n a) -> NExpr env (TArr n b)
+map_ (v :-> a) b = NEMap v a b
+
+fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
+ withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) ->
+ assertSymbolNotUnderscore s3 $
+ equalityReflexive s3 $
+ assertSymbolDistinct s3 s1 $
+ let v3 = Var s3 (STPair t t)
+ in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $
+ let_ v2 (snd_ (NEVar v3)) $
+ NEDrop (SS (SS SZ)) e1)
+ e2 e3
+
+fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3
+
+sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+sum1i e = NESum1Inner e
+
+unit :: NExpr env t -> NExpr env (TArr Z t)
+unit = NEUnit
+
+replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t))
+replicate1i n a = NEReplicate1Inner n a
+
+maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+maximum1i e = NEMaximum1Inner e
+
+minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+minimum1i e = NEMinimum1Inner e
+
+reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
+reshape = NEReshape
+
+fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b))
+ -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
+ withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) ->
+ assertSymbolNotUnderscore s3 $
+ equalityReflexive s3 $
+ assertSymbolDistinct s3 s1 $
+ let v3 = Var s3 (STPair t1 t1)
+ in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $
+ let_ v2 (snd_ (NEVar v3)) $
+ NEDrop (SS (SS SZ)) e1)
+ e2 e3
+
+fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b))
+ -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3
+
+fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2))
+ -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
+fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3
+
+const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
+const_ x =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConst ty x
+
+idx0 :: NExpr env (TArr Z t) -> NExpr env t
+idx0 = NEIdx0
+
+-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
+-- (.!) = NEIdx1
+-- infixl 9 .!
+
+(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
+(!) = NEIdx
+infixl 9 !
+
+shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
+shape = NEShape
+
+length_ :: NExpr env (TArr N1 t) -> NExpr env TIx
+length_ e = snd_ (shape e)
+
+oper :: SOp a t -> NExpr env a -> NExpr env t
+oper = NEOp
+
+oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t
+oper2 op a b = NEOp op (pair a b)
+
+error_ :: KnownTy t => String -> NExpr env t
+error_ s = NEError knownTy s
+
+custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
+ -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape))
+ -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b))
+ -> NExpr env a -> NExpr env b
+ -> NExpr env t
+custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =
+ NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2
+
+recompute :: NExpr env a -> NExpr env a
+recompute = NERecompute
+
+with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t)
+with a (n :-> b) = NEWith (knownMTy @t) a n b
+
+accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
+accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c
+
+accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
+accumS p a sp b c = NEAccum knownMTy p a sp b c
+
+
+(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+a .== b = oper (OEq knownScalTy) (pair a b)
+infix 4 .==
+
+(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+a .< b = oper (OLt knownScalTy) (pair a b)
+infix 4 .<
+
+(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+(.>) = flip (.<)
+infix 4 .>
+
+(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+a .<= b = oper (OLe knownScalTy) (pair a b)
+infix 4 .<=
+
+(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+(.>=) = flip (.<=)
+infix 4 .>=
+
+not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool)
+not_ = oper ONot
+
+and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
+and_ = oper2 OAnd
+infixr 3 `and_`
+
+or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
+or_ = oper2 OOr
+infixr 2 `or_`
+
+mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a)
+mod_ = oper2 (OMod knownScalTy)
+infixl 7 `mod_`
+
+-- | The first alternative is the True case; the second is the False case.
+if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
+if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b)
+
+round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64)
+round_ = oper ORound64
+
+toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64)
+toFloat_ = oper OToFl64
+
+idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t)
+idiv = oper2 (OIDiv knownScalTy)
+infixl 7 `idiv`