{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Language.AST where

import Data.Kind (Type)
import Data.Type.Equality
import GHC.OverloadedLabels
import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(Text))

import Array
import AST
import CHAD.Types
import Data


-- | Surface-language expressions with /named/ variables.
--
-- @env@ lists the variables in scope as (name, type) pairs. Binders such as
-- 'NELet' push onto the head of this list, so the head is the innermost
-- binding. @t@ is the type of the expression. Variable references ('NEVar')
-- are checked against @env@ at the type level via 'Lookup'.
type NExpr :: [(Symbol, Ty)] -> Ty -> Type
data NExpr env t where
  -- lambda calculus
  NEVar :: Lookup name env ~ t => Var name t -> NExpr env t
  NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t

  -- environment management
  NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t

  -- base types
  NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
  NEFst :: NExpr env (TPair a b) -> NExpr env a
  NESnd :: NExpr env (TPair a b) -> NExpr env b
  NENil :: NExpr env TNil
  NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b)
  NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b)
  NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c

  -- array operations
  NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
  NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
  NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
  NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEUnit :: NExpr env t -> NExpr env (TArr Z t)
  NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
  NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))

  -- expression operations
  NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
  NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
  NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
  NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
  NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
  NEOp :: SOp a t -> NExpr env a -> NExpr env t

  -- custom derivatives
  NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t  -- ^ regular operation
           -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)  -- ^ CHAD forward pass
           -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)  -- ^ CHAD reverse derivative
           -> NExpr env a -> NExpr env b
           -> NExpr env t

  -- partiality
  NEError :: STy a -> String -> NExpr env a

  -- embedded unnamed expressions
  NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
deriving instance Show (NExpr env t)

-- | The type of the innermost binding of @name@ in @env@.
--
-- Referring to a variable named @"_"@ is rejected with a custom type error.
-- If @name@ is not bound at all, the family does not reduce.
type family Lookup name env where
  Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'")
  Lookup name ('(name, t) : env) = t
  Lookup name (_ : env) = Lookup name env

-- | Remove the @i@-th entry from @env@, counting from the head
-- (the innermost binding), with 'Z' being the head itself.
type family DropNth i env where
  DropNth Z (_ : env) = env
  DropNth (S i) (p : env) = p : DropNth i env

-- | A variable: its name, as a singleton for the type-level symbol, together
-- with its type.
data Var name t = Var (SSymbol name) (STy t)
  deriving (Show)

-- | Arithmetic on scalar expressions.
--
-- Only '+', '*', 'negate' and 'fromInteger' build real operations ('-' uses
-- the class default in terms of '+' and 'negate'). 'abs' and 'signum' throw.
instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where
  a + b = NEOp (OAdd knownScalTy) (NEPair a b)
  a * b = NEOp (OMul knownScalTy) (NEPair a b)
  negate e = NEOp (ONeg knownScalTy) e
  abs = error "abs undefined for NExpr"
  signum = error "signum undefined for NExpr"
  -- The 'Dict' match brings 'Show (ScalRep st)' into scope, as 'NEConst' requires.
  fromInteger =
    let ty = knownScalTy
    in case scalRepIsShow ty of
         Dict -> NEConst ty . fromInteger

-- | Division on floating scalar expressions.
--
-- '(/)' uses the class default, which is multiplication by 'recip'.
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st)) => Fractional (NExpr env t) where
  recip e = NEOp (ORecip knownScalTy) e
  fromRational =
    let ty = knownScalTy
    in case scalRepIsShow ty of
         Dict -> NEConst ty . fromRational

-- | Transcendental functions on floating scalar expressions.
--
-- 'pi', 'exp' and 'log' are supported directly. 'sqrt', '(**)' and
-- 'logBase' fall back to the class defaults, which are expressed in terms of
-- 'exp' and 'log'. The trigonometric and hyperbolic functions have no
-- corresponding operation and throw.
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st)) => Floating (NExpr env t) where
  pi =
    let ty = knownScalTy
    in case scalRepIsShow ty of
         Dict -> NEConst ty pi
  exp = NEOp (OExp knownScalTy)
  -- Must be OLog (not OExp): the default sqrt, (**) and logBase are built on it.
  log = NEOp (OLog knownScalTy)
  sin = error "sin undefined for NExpr"
  cos = error "cos undefined for NExpr"
  tan = error "tan undefined for NExpr"
  asin = error "asin undefined for NExpr"
  acos = error "acos undefined for NExpr"
  atan = error "atan undefined for NExpr"
  sinh = error "sinh undefined for NExpr"
  cosh = error "cosh undefined for NExpr"
  asinh = error "asinh undefined for NExpr"
  acosh = error "acosh undefined for NExpr"
  atanh = error "atanh undefined for NExpr"

-- | @#x :: Var "x" t@ builds a variable whose name is taken from the label and
-- whose type singleton comes from 'KnownTy'.
instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where
  fromLabel = Var symbolSing knownTy

-- | @#x@ used as an expression is a reference to variable @x@. The 'Lookup'
-- constraint checks that @x@ is in scope in @env@ with type @t@.
instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where
  fromLabel = NEVar (fromLabel @name)

-- | A value-level environment of named variables, mirroring the type-level
-- list @env@. The innermost (most recently pushed) variable is on the
-- outside, written on the right as the second argument of 'NPush'.
data NEnv env where
  NTop :: NEnv '[]
  NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env)

-- | First (outermost) parameter on the outside, on the left.
-- * env: environment of this function (grows as you go deeper inside lambdas)
-- * env': environment of the body of the function
-- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list
-- * t: type of the function body
--
-- 'NLam' binds one parameter by pushing it onto the head of @env@. 'NBody'
-- ends the parameter list, at which point @env@ and @env'@ coincide.
data NFun env env' t where
  NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
  NBody :: NExpr env' t -> NFun env' env' t

-- | Strip the names from a named environment, keeping the types in the same
-- order (innermost first).
type family UnName env where
  UnName '[] = '[]
  UnName ('(name, t) : env) = t : UnName env

-- | The type singletons of all variables in a named environment, innermost
-- first. This is the value-level counterpart of 'UnName'.
envFromNEnv :: NEnv env -> SList STy (UnName env)
envFromNEnv NTop = SNil
envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env

-- | Embed a closed named function into an expression, applying it to @args@.
--
-- The function is first converted to de Bruijn form with 'fromNamed'.
-- @args@ is indexed by the function body's environment. 'NLam' pushes each
-- parameter onto the head of that environment, so the head of @args@
-- supplies the innermost (last) parameter.
inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t
inlineNFun fun args = NEUnnamed (fromNamed fun) args

-- | Convert a closed named function to a de Bruijn expression whose free
-- variables are the function's parameters.
fromNamed :: NFun '[] env t -> Ex (UnName env) t
fromNamed = fromNamedFun NTop

-- | Some of the parameters have already been put in the environment; some
-- haven't. Transfer all parameters to the left into the environment.
--
-- [] `fromNamedFun` λx y z. E
--   = []:x `fromNamedFun` λy z. E
--   = []:x:y `fromNamedFun` λz. E
--   = []:x:y:z `fromNamedFun` λ. E
--   = []:x:y:z `fromNamedExpr` E
fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t
fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun
fromNamedFun env (NBody e) = fromNamedExpr env e

-- | Convert a named expression into a de Bruijn ('Ex') expression over the
-- unnamed environment @UnName env@.
--
-- Variable references are resolved by searching @val@ from the innermost
-- binding outwards, matching both the name and the type. If no binding
-- matches, this calls 'error'. That is presumably unreachable for
-- well-typed, concrete environments, given the 'Lookup' constraint on
-- 'NEVar'.
fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t
fromNamedExpr val = \case
  NEVar var@(Var _ ty)
    | Just idx <- find var val -> EVar ext ty idx
    | otherwise -> error "Variable out of scope in conversion from surface \
                         \expression to De Bruijn expression"
  NELet n a b -> ELet ext (go a) (lambda val n b)
  -- Convert the body in the environment with the i-th variable removed, then
  -- weaken the result back into the full environment.
  NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e)
  NEPair a b -> EPair ext (go a) (go b)
  NEFst e -> EFst ext (go e)
  NESnd e -> ESnd ext (go e)
  NENil -> ENil ext
  NEInl t e -> EInl ext t (go e)
  NEInr t e -> EInr ext t (go e)
  NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b)
  NEConstArr n t x -> EConstArr ext n t x
  NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
  NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c)
  NESum1Inner e -> ESum1Inner ext (go e)
  NEUnit e -> EUnit ext (go e)
  NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
  NEMaximum1Inner e -> EMaximum1Inner ext (go e)
  NEMinimum1Inner e -> EMinimum1Inner ext (go e)
  NEConst t x -> EConst ext t x
  NEIdx0 e -> EIdx0 ext (go e)
  NEIdx1 a b -> EIdx1 ext (go a) (go b)
  NEIdx a b -> EIdx ext (go a) (go b)
  NEShape e -> EShape ext (go e)
  NEOp op e -> EOp ext op (go e)
  -- The three sub-programs of a custom derivative are closed. Each is
  -- converted in its own two-variable environment starting from NTop,
  -- independent of the surrounding 'val'. Only the two arguments e1 and e2
  -- use the current environment.
  NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 ->
    ECustom ext ta tb ttape
      (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a)
      (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
      (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
      (go e1) (go e2)
  NEError t s -> EError t s
  -- Weaken the embedded unnamed expression so that its own variables sit
  -- above the current environment, then let-bind each argument around it.
  NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args
  where
    -- Recurse in the same environment. 'env' is the scoped type variable
    -- from the signature above.
    go :: NExpr env t' -> Ex (UnName env) t'
    go = fromNamedExpr val

    -- De Bruijn index of the innermost entry whose name and type both match.
    find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t')
    find _ NTop = Nothing
    find var@(Var s ty) (val' `NPush` Var s' ty')
      | Just Refl <- testEquality s s'
      , Just Refl <- testEquality ty ty'
      = Just IZ
      | otherwise
      = IS <$> find var val'

    -- Convert a body that binds one extra variable.
    lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b
    lambda val' var e = fromNamedExpr (val' `NPush` var) e

    -- Convert a body that binds two extra variables; var2 ends up innermost.
    lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c
    lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e

    -- Wrap e in one ELet per argument. The head argument becomes the
    -- innermost let (directly around e), and the last argument the outermost.
    -- Each argument is converted in the current environment and then
    -- weakened by 'wSinks args', presumably to skip the binders of the
    -- remaining (outer) lets.
    injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t
    injectWrapLet e SNil = e
    injectWrapLet e (arg `SCons` args) =
      injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e)
                    args

-- | Remove the @i@-th variable (0 = innermost) from a named environment.
-- Calls 'error' if the environment is too short.
dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env)
dropNth SZ (val `NPush` _) = val
dropNth (SS i) (val `NPush` p) = dropNth i val `NPush` p
dropNth _ NTop = error "DropNth: index out of range"

-- | Weakening from the environment with the @i@-th variable removed back into
-- the full environment. The variables inside @i@ are copied ('WCopy') and
-- the dropped one is skipped ('WSink'). Calls 'error' if the environment is
-- too short.
dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
dropNthW SZ (_ `NPush` _) = WSink
dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val)
dropNthW _ NTop = error "DropNth: index out of range"