aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Language
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Language')
-rw-r--r--src/CHAD/Language/AST.hs300
1 files changed, 300 insertions, 0 deletions
diff --git a/src/CHAD/Language/AST.hs b/src/CHAD/Language/AST.hs
new file mode 100644
index 0000000..b270844
--- /dev/null
+++ b/src/CHAD/Language/AST.hs
@@ -0,0 +1,300 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module CHAD.Language.AST where
+
+import Data.Kind (Type)
+import Data.Type.Equality
+import GHC.OverloadedLabels
+import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal)
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Sparse.Types
+import CHAD.Data
+import CHAD.Drev.Types
+
+
+type NExpr :: [(Symbol, Ty)] -> Ty -> Type
+data NExpr env t where
+ -- lambda calculus
+ NEVar :: Lookup name env ~ t => Var name t -> NExpr env t
+ NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t
+
+ -- environment management
+ NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t
+
+ -- base types
+ NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
+ NEFst :: NExpr env (TPair a b) -> NExpr env a
+ NESnd :: NExpr env (TPair a b) -> NExpr env b
+ NENil :: NExpr env TNil
+ NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b)
+ NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b)
+ NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c
+ NENothing :: STy t -> NExpr env (TMaybe t)
+ NEJust :: NExpr env t -> NExpr env (TMaybe t)
+ NEMaybe :: NExpr env b -> Var name t -> NExpr ('(name, t) : env) b -> NExpr env (TMaybe t) -> NExpr env b
+
+ -- array operations
+ NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
+ NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
+ NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t)
+ NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+ NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+ NEUnit :: NExpr env t -> NExpr env (TArr Z t)
+ NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
+ NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+ NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+ NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
+
+ NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b)
+ -> NExpr env t1
+ -> NExpr env (TArr (S n) t1)
+ -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+ NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2)
+ -> NExpr env (TArr (S n) b)
+ -> NExpr env (TArr n t2)
+ -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
+
+ -- expression operations
+ NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
+ NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
+ NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
+ NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
+ NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
+ NEOp :: SOp a t -> NExpr env a -> NExpr env t
+
+ -- custom derivatives
+ NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t -- ^ regular operation
+ -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape) -- ^ CHAD forward pass
+ -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b) -- ^ CHAD reverse derivative
+ -> NExpr env a -> NExpr env b
+ -> NExpr env t
+
+ -- fake halfway checkpointing
+ NERecompute :: NExpr env t -> NExpr env t
+
+ -- accumulation effect on monoids
+ NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)
+ NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
+
+ -- partiality
+ NEError :: STy a -> String -> NExpr env a
+
+ -- embedded unnamed expressions
+ NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
+deriving instance Show (NExpr env t)
+
+type Lookup name env = Lookup1 (name == "_") name env
+type family Lookup1 eqblank name env where
+ Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'")
+ Lookup1 False name env = Lookup2 name env
+type family Lookup2 name env where
+ Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
+ Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env
+type family Lookup3 eq t name env where
+ Lookup3 True t _ _ = t
+ Lookup3 False _ name env = Lookup2 name env
+
+type family DropNth i env where
+ DropNth Z (_ : env) = env
+ DropNth (S i) (p : env) = p : DropNth i env
+
+data Var name t = Var (SSymbol name) (STy t)
+ deriving (Show)
+
+instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where
+ a + b = NEOp (OAdd knownScalTy) (NEPair a b)
+ a * b = NEOp (OMul knownScalTy) (NEPair a b)
+ negate e = NEOp (ONeg knownScalTy) e
+ abs = error "abs undefined for NExpr"
+ signum = error "signum undefined for NExpr"
+ fromInteger =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConst ty . fromInteger
+
+instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st))
+ => Fractional (NExpr env t) where
+ recip e = NEOp (ORecip knownScalTy) e
+ fromRational =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConst ty . fromRational
+
+instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st))
+ => Floating (NExpr env t) where
+ pi =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConst ty pi
+ exp = NEOp (OExp knownScalTy)
+ log = NEOp (OExp knownScalTy)
+ sin = undefined ; cos = undefined ; tan = undefined
+ asin = undefined ; acos = undefined ; atan = undefined
+ sinh = undefined ; cosh = undefined
+ asinh = undefined ; acosh = undefined ; atanh = undefined
+
+instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where
+ fromLabel = Var symbolSing knownTy
+
+instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where
+ fromLabel = NEVar (fromLabel @name)
+
+-- | Innermost variable variable on the outside, on the right.
+data NEnv env where
+ NTop :: NEnv '[]
+ NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env)
+
+-- | First (outermost) parameter on the outside, on the left.
+-- * env: environment of this function (grows as you go deeper inside lambdas)
+-- * env': environment of the body of the function
+-- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list
+data NFun env env' t where
+ NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
+ NBody :: NExpr env' t -> NFun env' env' t
+
+type family UnName env where
+ UnName '[] = '[]
+ UnName ('(name, t) : env) = t : UnName env
+
+envFromNEnv :: NEnv env -> SList STy (UnName env)
+envFromNEnv NTop = SNil
+envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env
+
+inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t
+inlineNFun fun args = NEUnnamed (fromNamed fun) args
+
+fromNamed :: NFun '[] env t -> Ex (UnName env) t
+fromNamed = fromNamedFun NTop
+
+-- | Some of the parameters have already been put in the environment; some
+-- haven't. Transfer all parameters to the left into the environment.
+--
+-- [] `fromNamedFun` λx y z. E
+-- = []:x `fromNamedFun` λy z. E
+-- = []:x:y `fromNamedFun` λz. E
+-- = []:x:y:z `fromNamedFun` λ. E
+-- = []:x:y:z `fromNamedExpr` E
+fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t
+fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun
+fromNamedFun env (NBody e) = fromNamedExpr env e
+
+fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t
+fromNamedExpr val = \case
+ NEVar var@(Var _ ty)
+ | Just idx <- find var val -> EVar ext ty idx
+ | otherwise -> error "Variable out of scope in conversion from surface \
+ \expression to De Bruijn expression"
+ NELet n a b -> ELet ext (go a) (lambda val n b)
+
+ NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e)
+
+ NEPair a b -> EPair ext (go a) (go b)
+ NEFst e -> EFst ext (go e)
+ NESnd e -> ESnd ext (go e)
+ NENil -> ENil ext
+ NEInl t e -> EInl ext t (go e)
+ NEInr t e -> EInr ext t (go e)
+ NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b)
+ NENothing t -> ENothing ext t
+ NEJust e -> EJust ext (go e)
+ NEMaybe a n b c -> EMaybe ext (go a) (lambda val n b) (go c)
+
+ NEConstArr n t x -> EConstArr ext n t x
+ NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
+ NEMap n a b -> EMap ext (lambda val n a) (go b)
+ NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c)
+ NESum1Inner e -> ESum1Inner ext (go e)
+ NEUnit e -> EUnit ext (go e)
+ NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
+ NEMaximum1Inner e -> EMaximum1Inner ext (go e)
+ NEMinimum1Inner e -> EMinimum1Inner ext (go e)
+ NEReshape n a b -> EReshape ext n (go a) (go b)
+
+ NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c)
+ NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+
+ NEConst t x -> EConst ext t x
+ NEIdx0 e -> EIdx0 ext (go e)
+ NEIdx1 a b -> EIdx1 ext (go a) (go b)
+ NEIdx a b -> EIdx ext (go a) (go b)
+ NEShape e -> EShape ext (go e)
+ NEOp op e -> EOp ext op (go e)
+
+ NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 ->
+ ECustom ext ta tb ttape
+ (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a)
+ (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
+ (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
+ (go e1) (go e2)
+ NERecompute e -> ERecompute ext (go e)
+
+ NEWith t a n b -> EWith ext t (go a) (lambda val n b)
+ NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c)
+
+ NEError t s -> EError ext t s
+
+ NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args
+ where
+ go :: NExpr env t' -> Ex (UnName env) t'
+ go = fromNamedExpr val
+
+ find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t')
+ find _ NTop = Nothing
+ find var@(Var s ty) (val' `NPush` Var s' ty')
+ | Just Refl <- testEquality s s'
+ , Just Refl <- testEquality ty ty'
+ = Just IZ
+ | otherwise
+ = IS <$> find var val'
+
+ lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b
+ lambda val' var e = fromNamedExpr (val' `NPush` var) e
+
+ lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c
+ lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e
+
+ injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t
+ injectWrapLet e SNil = e
+ injectWrapLet e (arg `SCons` args) =
+ injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e)
+ args
+
+dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env)
+dropNth SZ (val `NPush` _) = val
+dropNth (SS i) (val `NPush` p) = dropNth i val `NPush` p
+dropNth _ NTop = error "DropNth: index out of range"
+
+dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
+dropNthW SZ (_ `NPush` _) = WSink
+dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val)
+dropNthW _ NTop = error "DropNth: index out of range"
+
+assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r
+assertSymbolNotUnderscore s@SSymbol k =
+ case symbolVal s of
+ "_" -> error "assertSymbolNotUnderscore: was underscore"
+ _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k
+
+assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r
+assertSymbolDistinct s1@SSymbol s2@SSymbol k
+ | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")"
+ | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k
+
+equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r
+equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k