{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Surface expressions with named variables ('NExpr'), and their conversion
-- to the De Bruijn-indexed 'Ex' representation.
module Language.AST where

import Data.Kind (Type)
import Data.Type.Equality
import GHC.OverloadedLabels
import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal)

import Array
import AST
import AST.Sparse.Types
import CHAD.Types
import Data


-- | An expression with named variables. The index @env@ lists the variables
-- in scope as @(name, type)@ pairs, with the innermost binding at the head;
-- @t@ is the type of the expression.
type NExpr :: [(Symbol, Ty)] -> Ty -> Type
data NExpr env t where
  -- lambda calculus
  NEVar :: Lookup name env ~ t => Var name t -> NExpr env t
  NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t

  -- environment management
  NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t

  -- base types
  NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
  NEFst :: NExpr env (TPair a b) -> NExpr env a
  NESnd :: NExpr env (TPair a b) -> NExpr env b
  NENil :: NExpr env TNil
  NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b)
  NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b)
  NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c
  NENothing :: STy t -> NExpr env (TMaybe t)
  NEJust :: NExpr env t -> NExpr env (TMaybe t)
  NEMaybe :: NExpr env b -> Var name t -> NExpr ('(name, t) : env) b -> NExpr env (TMaybe t) -> NExpr env b

  -- array operations
  NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
  NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
  NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
  NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEUnit :: NExpr env t -> NExpr env (TArr Z t)
  NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
  NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)

  NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b)
                 -> NExpr env t1
                 -> NExpr env (TArr (S n) t1)
                 -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
  NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2)
                 -> NExpr env (TArr (S n) b)
                 -> NExpr env (TArr n t2)
                 -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))

  -- expression operations
  NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
  NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
  NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
  NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
  NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
  NEOp :: SOp a t -> NExpr env a -> NExpr env t

  -- custom derivatives
  NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t  -- ^ regular operation
           -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)  -- ^ CHAD forward pass
           -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)  -- ^ CHAD reverse derivative
           -> NExpr env a -> NExpr env b
           -> NExpr env t

  -- fake halfway checkpointing
  NERecompute :: NExpr env t -> NExpr env t

  -- accumulation effect on monoids
  NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)
  NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil

  -- partiality
  NEError :: STy a -> String -> NExpr env a

  -- embedded unnamed expressions
  NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
deriving instance Show (NExpr env t)

-- | The type of variable @name@ in @env@. The search starts at the head of
-- @env@, so the innermost binding of a name shadows outer ones. Referring to
-- @"_"@, or to a name that is not in @env@, is reported as a type error.
type Lookup name env = Lookup1 (name == "_") name env

-- | Rejects the blank name @"_"@, then defers to 'Lookup2'.
type family Lookup1 eqblank name env where
  Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'")
  Lookup1 False name env = Lookup2 name env

-- | Walks @env@ from the head, comparing each entry's name with @name@.
type family Lookup2 name env where
  Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
  Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env

-- | Returns @t@ on a name match, otherwise continues the search in the tail.
type family Lookup3 eq t name env where
  Lookup3 True t _ _ = t
  Lookup3 False _ name env = Lookup2 name env

-- | Removes the element at (zero-based) position @i@, counted from the head.
type family DropNth i env where
  DropNth Z (_ : env) = env
  DropNth (S i) (p : env) = p : DropNth i env

-- | A named variable: its name as a singleton symbol, plus its type.
data Var name t = Var (SSymbol name) (STy t)
  deriving (Show)

-- | Arithmetic on scalar expressions builds 'NEOp' nodes and numeric literals
-- become 'NEConst'. 'abs' and 'signum' are not supported and raise an error.
instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where
  a + b = NEOp (OAdd knownScalTy) (NEPair a b)
  a * b = NEOp (OMul knownScalTy) (NEPair a b)
  negate e = NEOp (ONeg knownScalTy) e
  abs = error "abs undefined for NExpr"
  signum = error "signum undefined for NExpr"
  -- 'scalRepIsShow' supplies the Show dictionary that 'NEConst' requires.
  fromInteger = let ty = knownScalTy in case scalRepIsShow ty of Dict -> NEConst ty . fromInteger

-- | Division via 'ORecip' (base's default @x / y = x * recip y@); fractional
-- literals become 'NEConst'.
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st)) => Fractional (NExpr env t) where
  recip e = NEOp (ORecip knownScalTy) e
  fromRational = let ty = knownScalTy in case scalRepIsShow ty of Dict -> NEConst ty . fromRational

-- | Only 'pi', 'exp' and 'log' are supported. The other primitive methods raise
-- an error, like 'abs'/'signum' in the 'Num' instance.
--
-- Note: base's default definitions of 'sqrt', '(**)' and 'logBase' are written
-- in terms of 'exp' and 'log'. So 'log' must really build a logarithm node, or
-- those defaults silently compute the wrong function.
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st)) => Floating (NExpr env t) where
  pi = let ty = knownScalTy in case scalRepIsShow ty of Dict -> NEConst ty pi
  exp = NEOp (OExp knownScalTy)
  -- Was mistakenly 'OExp' before, which made 'log' (and sqrt/**/logBase) wrong.
  log = NEOp (OLog knownScalTy)
  sin = error "sin undefined for NExpr"
  cos = error "cos undefined for NExpr"
  tan = error "tan undefined for NExpr"
  asin = error "asin undefined for NExpr"
  acos = error "acos undefined for NExpr"
  atan = error "atan undefined for NExpr"
  sinh = error "sinh undefined for NExpr"
  cosh = error "cosh undefined for NExpr"
  asinh = error "asinh undefined for NExpr"
  acosh = error "acosh undefined for NExpr"
  atanh = error "atanh undefined for NExpr"

-- | @#x :: Var "x" t@ builds a variable from an overloaded label.
instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where
  fromLabel = Var symbolSing knownTy

-- | @#x :: NExpr env t@ refers to the in-scope variable @x@; its type is
-- determined by 'Lookup'.
instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where
  fromLabel = NEVar (fromLabel @name)

-- | Innermost variable on the outside, on the right.
data NEnv env where
  NTop :: NEnv '[]
  NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env)

-- | First (outermost) parameter on the outside, on the left.
--
-- * env: environment of this function (grows as you go deeper inside lambdas)
-- * env': environment of the body of the function
-- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list
data NFun env env' t where
  NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
  NBody :: NExpr env' t -> NFun env' env' t

-- | Drop the names from a named environment, keeping only the types.
type family UnName env where
  UnName '[] = '[]
  UnName ('(name, t) : env) = t : UnName env

-- | The types of the variables in an 'NEnv', in the same order as
-- @UnName env@ (innermost variable first).
envFromNEnv :: NEnv env -> SList STy (UnName env)
envFromNEnv NTop = SNil
envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env

-- | Embed a closed named function, applied to argument expressions, as an
-- 'NEUnnamed' node. The arguments are given in the order of @UnName envB@,
-- i.e. head = innermost binding (the last 'NLam').
inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t
inlineNFun fun args = NEUnnamed (fromNamed fun) args

-- | Convert a closed named function to a De Bruijn expression over the
-- function's body environment.
fromNamed :: NFun '[] env t -> Ex (UnName env) t
fromNamed = fromNamedFun NTop

-- | Some of the parameters have already been put in the environment; some
-- haven't. Transfer all parameters to the left into the environment.
--
-- >     [] `fromNamedFun` λx y z. E
-- >   = []:x `fromNamedFun` λy z. E
-- >   = []:x:y `fromNamedFun` λz. E
-- >   = []:x:y:z `fromNamedFun` λ. E
-- >   = []:x:y:z `fromNamedExpr` E
fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t
fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun
fromNamedFun env (NBody e) = fromNamedExpr env e

-- | Translate a named expression to a De Bruijn expression, given the named
-- environment it lives in. Variables are resolved by matching both name and
-- type, starting from the innermost binding. A variable not found this way
-- raises an error at runtime.
fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t
fromNamedExpr val = \case
  NEVar var@(Var _ ty)
    | Just idx <- find var val -> EVar ext ty idx
    | otherwise -> error "Variable out of scope in conversion from surface \
                         \expression to De Bruijn expression"
  NELet n a b -> ELet ext (go a) (lambda val n b)

  NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e)

  NEPair a b -> EPair ext (go a) (go b)
  NEFst e -> EFst ext (go e)
  NESnd e -> ESnd ext (go e)
  NENil -> ENil ext
  NEInl t e -> EInl ext t (go e)
  NEInr t e -> EInr ext t (go e)
  NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b)
  NENothing t -> ENothing ext t
  NEJust e -> EJust ext (go e)
  NEMaybe a n b c -> EMaybe ext (go a) (lambda val n b) (go c)

  NEConstArr n t x -> EConstArr ext n t x
  NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
  NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c)
  NESum1Inner e -> ESum1Inner ext (go e)
  NEUnit e -> EUnit ext (go e)
  NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
  NEMaximum1Inner e -> EMaximum1Inner ext (go e)
  NEMinimum1Inner e -> EMinimum1Inner ext (go e)
  NEReshape n a b -> EReshape ext n (go a) (go b)

  NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c)
  NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)

  NEConst t x -> EConst ext t x
  NEIdx0 e -> EIdx0 ext (go e)
  NEIdx1 a b -> EIdx1 ext (go a) (go b)
  NEIdx a b -> EIdx ext (go a) (go b)
  NEShape e -> EShape ext (go e)
  NEOp op e -> EOp ext op (go e)

  -- The three sub-expressions are closed over their own two-variable
  -- environments, so they are translated starting from 'NTop'.
  NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 ->
    ECustom ext ta tb ttape
            (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a)
            (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
            (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
            (go e1) (go e2)

  NERecompute e -> ERecompute ext (go e)

  NEWith t a n b -> EWith ext t (go a) (lambda val n b)
  NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c)

  NEError t s -> EError ext t s

  -- Weaken the unnamed expression so that the current environment sits below
  -- its own environment, then bind its environment with one 'ELet' per argument.
  NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args

  where
    -- | Translate a subexpression in the current environment.
    go :: NExpr env t' -> Ex (UnName env) t'
    go = fromNamedExpr val

    -- | De Bruijn index of the innermost entry whose name and type both equal
    -- those of the variable, if any.
    find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t')
    find _ NTop = Nothing
    find var@(Var s ty) (val' `NPush` Var s' ty')
      | Just Refl <- testEquality s s'
      , Just Refl <- testEquality ty ty'
      = Just IZ
      | otherwise
      = IS <$> find var val'

    -- | Translate a body that binds one extra variable.
    lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b
    lambda val' var e = fromNamedExpr (val' `NPush` var) e

    -- | Translate a body that binds two extra variables (@var2@ innermost).
    lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c
    lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e

    -- | Bind each argument with an 'ELet'. The head of the argument list
    -- becomes the innermost let; each argument is translated in the current
    -- environment and weakened past the lets of the remaining arguments.
    injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t
    injectWrapLet e SNil = e
    injectWrapLet e (arg `SCons` args) =
      injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e)
                    args

-- | Remove the @i@-th (zero-based, innermost-first) variable from an 'NEnv'.
-- Raises an error if the environment is too short.
dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env)
dropNth SZ (val `NPush` _) = val
dropNth (SS i) (val `NPush` p) = dropNth i val `NPush` p
dropNth _ NTop = error "DropNth: index out of range"

-- | Weakening from an environment with its @i@-th variable removed back to
-- the full environment. Raises an error if the environment is too short.
dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
dropNthW SZ (_ `NPush` _) = WSink
dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val)
dropNthW _ NTop = error "DropNth: index out of range"

-- | Check at runtime that the symbol is not @"_"@, and if so run the
-- continuation with that fact available as a type-level constraint. The
-- constraint is supplied via 'unsafeCoerceRefl' (from an imported module),
-- justified by the runtime check. Raises an error if the symbol is @"_"@.
assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r
assertSymbolNotUnderscore s@SSymbol k =
  case symbolVal s of
    "_" -> error "assertSymbolNotUnderscore: was underscore"
    _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k

-- | Check at runtime that two symbols differ, and if so run the continuation
-- with @(s1 == s2) ~ False@ available. The constraint is supplied via
-- 'unsafeCoerceRefl', justified by the runtime comparison. Raises an error if
-- the symbols are equal.
assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r
assertSymbolDistinct s1@SSymbol s2@SSymbol k
  | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")"
  | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k

-- | Provide @(s == s) ~ True@ for an arbitrary symbol @s@. The proof is
-- supplied unchecked via 'unsafeCoerceRefl'. NOTE(review): presumably needed
-- because GHC does not reduce @s == s@ for an abstract @s@.
equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r
equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k