{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module CHAD.Language.AST where

import Data.Kind (Type)
import Data.Type.Equality
import GHC.OverloadedLabels
import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal)

import CHAD.Array
import CHAD.AST
import CHAD.AST.Sparse.Types
import CHAD.Data
import CHAD.Drev.Types


-- | Expressions whose variables are referred to by /name/ instead of by De
-- Bruijn index. Apart from that, the constructors mirror those of 'Expr'.
--
-- The environment index is a type-level list of @(name, type)@ pairs; binders
-- push their variable onto the head of this list.
type NExpr :: [(Symbol, Ty)] -> Ty -> Type
data NExpr env t where
  -- Lambda calculus

  -- | Reference to a variable; its type must be what 'Lookup' finds for the
  -- name in the environment.
  NEVar :: Lookup name env ~ t => Var name t -> NExpr env t
  -- | Let-binding; the body sees the new name at the head of its environment.
  NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t

  -- Environment management

  -- | Interpret the body in the environment with entry @i@ removed (see
  -- 'DropNth').
  NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t

  -- Base types
  NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
  NEFst :: NExpr env (TPair a b) -> NExpr env a
  NESnd :: NExpr env (TPair a b) -> NExpr env b
  NENil :: NExpr env TNil
  NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b)
  NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b)
  NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c
  NENothing :: STy t -> NExpr env (TMaybe t)
  NEJust :: NExpr env t -> NExpr env (TMaybe t)
  NEMaybe :: NExpr env b -> Var name t -> NExpr ('(name, t) : env) b -> NExpr env (TMaybe t) -> NExpr env b

  -- Array operations
  NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
  NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
  NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t)
  NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
  NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEUnit :: NExpr env t -> NExpr env (TArr Z t)
  NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
  NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)

  NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
  NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2) -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))

  -- Expression operations
  NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
  NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
  NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
  NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
  NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
  NEOp :: SOp a t -> NExpr env a -> NExpr env t

  -- Custom derivatives
  NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t  -- ^ regular operation
           -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)  -- ^ CHAD forward pass
           -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)  -- ^ CHAD reverse derivative
           -> NExpr env a -> NExpr env b
           -> NExpr env t

  -- Fake halfway checkpointing
  NERecompute :: NExpr env t -> NExpr env t

  -- Accumulation effect on monoids
  NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)
  NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil

  -- Partiality
  NEError :: STy a -> String -> NExpr env a

  -- Embedded unnamed expressions

  -- | Embed an unnamed 'Ex'; its free variables are supplied, in order, by
  -- the given named expressions.
  NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
deriving instance Show (NExpr env t)

-- | The type bound to a name in a named environment (the innermost binding
-- wins).
type Lookup name env = Lookup1 (name == "_") name env

-- | 'Lookup' is implemented as a chain of type families that branch on the
-- boolean result of the decidable predicate "==", rather than by pattern
-- matching on the names themselves. Consequently, evidence such as
-- @(name1 == name2) ~ False@ can make a lookup reduce even when the names
-- involved are not statically known. This is exploited e.g. via
-- 'assertSymbolDistinct' in 'CHAD.Language.fold1i'.
type family Lookup1 eqblank name env where
  Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'")
  Lookup1 False name env = Lookup2 name env

-- | Walk the environment from its head outwards, stopping at the first entry
-- whose name is equal (by "==") to @name@; a type error if there is none.
type family Lookup2 name env where
  Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
  Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env

-- | Helper for 'Lookup2': yield @t@ if the names matched, otherwise continue
-- with the rest of the environment.
type family Lookup3 eq t name env where
  Lookup3 True t _ _ = t
  Lookup3 False _ name env = Lookup2 name env

-- | Remove the element at (zero-based) position @i@, counted from the head of
-- the type-level list.
type family DropNth i env where
  DropNth Z (_ : env) = env
  DropNth (S i) (p : env) = p : DropNth i env

-- | A variable: a singleton for its name together with a singleton for its
-- type.
data Var name t = Var (SSymbol name) (STy t)
  deriving (Show)

-- | Arithmetic on scalar expressions, built from 'NEOp'. 'abs' and 'signum'
-- are not supported and call 'error'.
instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st))
      => Num (NExpr env t) where
  a + b = NEOp (OAdd knownScalTy) (NEPair a b)
  a * b = NEOp (OMul knownScalTy) (NEPair a b)
  negate e = NEOp (ONeg knownScalTy) e
  abs = error "abs undefined for NExpr"
  signum = error "signum undefined for NExpr"
  fromInteger = let ty = knownScalTy in case scalRepIsShow ty of Dict -> NEConst ty . fromInteger

-- | Division on floating-point scalar expressions; '(/)' is the class
-- default, expressed via '(*)' and 'recip'.
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st))
      => Fractional (NExpr env t) where
  recip e = NEOp (ORecip knownScalTy) e
  fromRational = let ty = knownScalTy in case scalRepIsShow ty of Dict -> NEConst ty . fromRational

-- | Floating-point operations on scalar expressions. Only 'pi', 'exp' and
-- 'log' are backed by primitives; the trigonometric and hyperbolic functions
-- call 'error'.
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st))
      => Floating (NExpr env t) where
  pi = let ty = knownScalTy in case scalRepIsShow ty of Dict -> NEConst ty pi
  exp = NEOp (OExp knownScalTy)
  -- Must be 'OLog' (not 'OExp'): besides 'log' itself, the class defaults for
  -- 'sqrt', '(**)' and 'logBase', which are not overridden here, are built
  -- from 'log'.
  log = NEOp (OLog knownScalTy)
  sin = error "sin undefined for NExpr"
  cos = error "cos undefined for NExpr"
  tan = error "tan undefined for NExpr"
  asin = error "asin undefined for NExpr"
  acos = error "acos undefined for NExpr"
  atan = error "atan undefined for NExpr"
  sinh = error "sinh undefined for NExpr"
  cosh = error "cosh undefined for NExpr"
  asinh = error "asinh undefined for NExpr"
  acosh = error "acosh undefined for NExpr"
  atanh = error "atanh undefined for NExpr"

-- | @#name@ yields a 'Var' with that name; its type singleton comes from
-- 'KnownTy'.
instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where
  fromLabel = Var symbolSing knownTy

-- | @#name@ used as an expression refers to the variable @name@, at the type
-- 'Lookup' finds for it in the environment.
instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where
  fromLabel = NEVar (fromLabel @name)

-- | A named environment, built like a snoc-list: the innermost (most recently
-- bound) variable is the right-hand argument of the outermost 'NPush', and
-- sits at the head of the type-level list.
data NEnv env where
  NTop :: NEnv '[]
  NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env)

-- | A named /function/. These can be used in only two ways: they can be
-- converted to an unnamed 'Expr' using 'fromNamed', and they can be inlined
-- using 'CHAD.Language.inline'.
--
-- * @env@: the environment at this point of the function. It is smaller than
--   @env'@ and gains one entry under every 'NLam'.
-- * @env'@: the environment of the function's body.
--
-- As an example, the function @(\\(x :: a) (y :: b) -> _ :: c)@ with the two
-- free variables @u :: t1@ and @v :: t2@ is represented by a value of type:
--
-- @
-- NFun '['("v", t2), '("u", t1)] '['("y", b), '("x", a), '("v", t2), '("u", t1)] c
-- @
data NFun env env' t where
  NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
  NBody :: NExpr env' t -> NFun env' env' t

-- | Erase the names of a named environment, keeping the types in the same
-- order.
type family UnName env where
  UnName '[] = '[]
  UnName ('(name, t) : env) = t : UnName env

-- | Collect the type singletons of a named environment, head first.
envFromNEnv :: NEnv env -> SList STy (UnName env)
envFromNEnv = \case
  NTop -> SNil
  NPush rest (Var _ ty) -> ty `SCons` envFromNEnv rest

-- | Inline a closed named function at the given argument expressions: the
-- function is converted with 'fromNamed' and embedded via 'NEUnnamed'.
inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t
inlineNFun fun = NEUnnamed (fromNamed fun)

-- | Convert a named function to an unnamed expression with free variables,
-- ready for consumption by the rest of this library. The function must be
-- closed (the function as a whole may not have free variables); its arguments
-- become the free variables of the resulting expression. Typical usage:
--
-- @
-- {-# LANGUAGE OverloadedLabels #-}
-- import CHAD.Language
-- 'fromNamed' $ 'CHAD.Language.lambda' \@(TScal TF64) #x $ 'CHAD.Language.lambda' \@(TScal TI64) #i $ 'CHAD.Language.body' $ #x + 'CHAD.Language.toFloat_' #i
--   :: 'Ex' '[TScal TI64, TScal TF64] (TScal TF64)
-- @
--
-- Elsewhere in the library, expressions with free variables generally stand in
-- for "functions": the free variables are regarded as the function's inputs.
--
-- Note that while environments normally grow to the right (e.g. in type theory
-- notation), as they are type-level lists here, they grow to the /left/.
-- This is why the second (innermost) argument of the example, @i@, ends up at
-- the head of the environment of the constructed expression.
--
-- __Type applications__: the type applications to 'CHAD.Language.lambda'
-- above are good practice but not always required; when GHC can infer the
-- argument's type from the body of the expression, they may be omitted.
--
-- __Variables__: the main piece of syntactic sugar in this module is the use
-- of OverloadedLabels for variable names. Variables are represented in
-- 'NExpr' (and thus 'NFun') by the 'Var' type, which you should never need to
-- construct by hand: 'Var' implements 'IsLabel', so @#name@ produces one,
-- where "name" is the name of the variable. Such a reference is polymorphic;
-- its (embedded) type is left to GHC's type inference via a 'KnownTy'
-- constraint. See also 'CHAD.Language.let_'.
fromNamed :: NFun '[] env t -> Ex (UnName env) t
fromNamed = fromNamedFun NTop

-- | Convert a function of which some parameters have already been moved into
-- the environment. Each remaining 'NLam' parameter is pushed onto the
-- environment in turn, after which the body goes to 'fromNamedExpr':
--
-- > [] `fromNamedFun` λx y z. E
-- >   = []:x `fromNamedFun` λy z. E
-- >   = []:x:y `fromNamedFun` λz. E
-- >   = []:x:y:z `fromNamedFun` λ. E
-- >   = []:x:y:z `fromNamedExpr` E
fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t
fromNamedFun env = \case
  NLam var fun -> fromNamedFun (env `NPush` var) fun
  NBody e -> fromNamedExpr env e

-- | Translate a named expression into a De Bruijn-indexed 'Ex', given the
-- named environment it lives in. A variable resolves to the innermost binding
-- whose name and type both match; if none exists, 'error' is called.
-- 'NEUnnamed' subterms are weakened into the current environment and their
-- arguments are let-bound around them.
fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t
fromNamedExpr nenv = \case
    NEVar v@(Var _ ty) ->
      maybe (error "Variable out of scope in conversion from surface \
                   \expression to De Bruijn expression")
            (EVar ext ty)
            (lookupVar v nenv)
    NELet v rhs body -> ELet ext (sub rhs) (under nenv v body)
    NEDrop i e -> weakenExpr (dropNthW i nenv) (fromNamedExpr (dropNth i nenv) e)

    NEPair l r -> EPair ext (sub l) (sub r)
    NEFst e -> EFst ext (sub e)
    NESnd e -> ESnd ext (sub e)
    NENil -> ENil ext
    NEInl ty e -> EInl ext ty (sub e)
    NEInr ty e -> EInr ext ty (sub e)
    NECase e v1 l v2 r -> ECase ext (sub e) (under nenv v1 l) (under nenv v2 r)
    NENothing ty -> ENothing ext ty
    NEJust e -> EJust ext (sub e)
    NEMaybe dflt v f m -> EMaybe ext (sub dflt) (under nenv v f) (sub m)

    NEConstArr n ty arr -> EConstArr ext n ty arr
    NEBuild k sh v f -> EBuild ext k (sub sh) (under nenv v f)
    NEMap v f arr -> EMap ext (under nenv v f) (sub arr)
    NEFold1Inner v f x0 arr -> EFold1Inner ext Noncommut (under nenv v f) (sub x0) (sub arr)
    NESum1Inner e -> ESum1Inner ext (sub e)
    NEUnit e -> EUnit ext (sub e)
    NEReplicate1Inner n arr -> EReplicate1Inner ext (sub n) (sub arr)
    NEMaximum1Inner e -> EMaximum1Inner ext (sub e)
    NEMinimum1Inner e -> EMinimum1Inner ext (sub e)
    NEReshape n sh arr -> EReshape ext n (sub sh) (sub arr)

    NEFold1InnerD1 v f x0 arr -> EFold1InnerD1 ext Noncommut (under nenv v f) (sub x0) (sub arr)
    NEFold1InnerD2 v1 v2 f p q -> EFold1InnerD2 ext Noncommut (under2 nenv v1 v2 f) (sub p) (sub q)

    NEConst ty x -> EConst ext ty x
    NEIdx0 e -> EIdx0 ext (sub e)
    NEIdx1 arr i -> EIdx1 ext (sub arr) (sub i)
    NEIdx arr ix -> EIdx ext (sub arr) (sub ix)
    NEShape e -> EShape ext (sub e)
    NEOp op e -> EOp ext op (sub e)

    -- The three bodies of a custom derivative are closed, so each is converted
    -- in a fresh environment holding only its own two binders.
    NECustom v1@(Var _ ta) v2@(Var _ tb) prim vf1 vf2 fwd vr1@(Var _ ttape) vr2 rev e1 e2 ->
      ECustom ext ta tb ttape
        (fromNamedExpr (NTop `NPush` v1 `NPush` v2) prim)
        (fromNamedExpr (NTop `NPush` vf1 `NPush` vf2) fwd)
        (fromNamedExpr (NTop `NPush` vr1 `NPush` vr2) rev)
        (sub e1)
        (sub e2)

    NERecompute e -> ERecompute ext (sub e)

    NEWith mty e0 v body -> EWith ext mty (sub e0) (under nenv v body)
    NEAccum mty prj ix sp x acc -> EAccum ext mty prj (sub ix) sp (sub x) (sub acc)

    NEError ty msg -> EError ext ty msg

    NEUnnamed e args -> wrapLets (weakenExpr (wRaiseAbove args (envFromNEnv nenv)) e) args
  where
    -- Recurse on a subterm living in the same environment.
    sub :: NExpr env t' -> Ex (UnName env) t'
    sub = fromNamedExpr nenv

    -- De Bruijn index of the innermost entry matching both name and type.
    lookupVar :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t')
    lookupVar _ NTop = Nothing
    lookupVar v@(Var sym ty) (outer `NPush` Var sym' ty')
      | Just Refl <- testEquality sym sym'
      , Just Refl <- testEquality ty ty'
      = Just IZ
      | otherwise
      = IS <$> lookupVar v outer

    -- Convert a body under one new binder.
    under :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b
    under outer v = fromNamedExpr (outer `NPush` v)

    -- Convert a body under two new binders; the second one ends up innermost.
    under2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c
    under2 outer v1 v2 = fromNamedExpr (outer `NPush` v1 `NPush` v2)

    -- Let-bind the arguments of an 'NEUnnamed' around its (weakened) body.
    -- The head argument becomes the innermost let; each argument is
    -- converted in the current environment and weakened past the remaining
    -- arguments.
    wrapLets :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t
    wrapLets e SNil = e
    wrapLets e (arg `SCons` rest) =
      wrapLets (ELet ext (weakenExpr (wSinks rest) (fromNamedExpr nenv arg)) e) rest

-- | Remove entry @i@ (zero-based, from the head) of a named environment,
-- mirroring 'DropNth'. Calls 'error' if the environment is too short.
dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env)
dropNth SZ (rest `NPush` _) = rest
dropNth (SS i) (rest `NPush` v) = dropNth i rest `NPush` v
dropNth _ NTop = error "DropNth: index out of range"

-- | The weakening witnessing that the environment without entry @i@ embeds
-- into the full one. Calls 'error' if the environment is too short.
dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
dropNthW SZ (_ `NPush` _) = WSink
dropNthW (SS i) (rest `NPush` _) = WCopy (dropNthW i rest)
dropNthW _ NTop = error "DropNth: index out of range"

-- | Run the continuation with evidence that @s@ is not @"_"@. The evidence is
-- produced with 'unsafeCoerceRefl' only after a runtime comparison of the
-- symbol's string has ruled out @"_"@; otherwise 'error' is called.
assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r
assertSymbolNotUnderscore s@SSymbol k
  | symbolVal s == "_" = error "assertSymbolNotUnderscore: was underscore"
  | otherwise = case unsafeCoerceRefl @(s == "_") @False of Refl -> k

-- | Run the continuation with evidence that @s1@ and @s2@ differ. The evidence
-- is produced with 'unsafeCoerceRefl' only after a runtime comparison of the
-- two symbols' strings; if they are equal, 'error' is called.
assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r
assertSymbolDistinct s1@SSymbol s2@SSymbol k =
  let name1 = symbolVal s1
  in if name1 == symbolVal s2
       then error $ "assertSymbolDistinct: was equal (" ++ name1 ++ ")"
       else case unsafeCoerceRefl @(s1 == s2) @False of Refl -> k

-- | Run the continuation with the evidence @(s == s) ~ True@, which GHC does
-- not derive by itself for an abstract @s@; it is supplied via
-- 'unsafeCoerceRefl'.
equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r
equalityReflexive _ k = case unsafeCoerceRefl @(s == s) @True of Refl -> k