From 174af2ba568de66e0d890825b8bda930b8e7bb96 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Nov 2025 21:49:45 +0100 Subject: Move module hierarchy under CHAD. --- src/CHAD/Language/AST.hs | 300 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 300 insertions(+) create mode 100644 src/CHAD/Language/AST.hs (limited to 'src/CHAD/Language/AST.hs') diff --git a/src/CHAD/Language/AST.hs b/src/CHAD/Language/AST.hs new file mode 100644 index 0000000..b270844 --- /dev/null +++ b/src/CHAD/Language/AST.hs @@ -0,0 +1,300 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.Language.AST where + +import Data.Kind (Type) +import Data.Type.Equality +import GHC.OverloadedLabels +import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types + + +type NExpr :: [(Symbol, Ty)] -> Ty -> Type +data NExpr env t where + -- lambda calculus + NEVar :: Lookup name env ~ t => Var name t -> NExpr env t + NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t + + -- environment management + NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t + + -- base types + NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) + NEFst :: NExpr env (TPair a b) -> NExpr env a + NESnd :: NExpr env (TPair a b) -> NExpr env b + NENil :: NExpr env TNil + NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b) + NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b) + NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c + NENothing :: STy t -> NExpr env (TMaybe t) + NEJust :: NExpr env t -> NExpr env (TMaybe t) + NEMaybe :: NExpr env b -> Var name t -> NExpr ('(name, t) : env) b -> NExpr env (TMaybe t) -> NExpr env b + + -- array operations + NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) + NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) + NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t) + NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEUnit :: NExpr env t -> NExpr env (TArr Z t) + NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) + NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) + + NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) + -> NExpr env t1 + -> NExpr env (TArr (S n) t1) + -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) + NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2) + -> NExpr env (TArr (S n) b) + -> NExpr env (TArr n t2) + -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) + + -- expression operations + NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) + NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t + NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) + NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t + NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) + NEOp :: SOp a t -> NExpr env a -> NExpr env t + + -- custom derivatives + NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t -- ^ regular operation + -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b) -- ^ CHAD reverse derivative + -> NExpr env a -> NExpr env b + -> NExpr env t + + -- fake halfway checkpointing + NERecompute :: NExpr env t -> NExpr env t + + -- accumulation effect on monoids + NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil + + -- partiality + NEError :: STy a -> String -> NExpr env a + + -- embedded unnamed expressions + NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t +deriving instance Show (NExpr env t) + +type Lookup name env = Lookup1 (name == "_") name env +type family Lookup1 eqblank name env where + Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup1 False name env = Lookup2 name env +type family Lookup2 name env where + Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") + Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env +type family Lookup3 eq t name env where + Lookup3 True t _ _ = t + Lookup3 False _ name env = Lookup2 name env + +type family DropNth i env where + DropNth Z (_ : env) = env + DropNth (S i) (p : env) = p : DropNth i env + +data Var name t = Var (SSymbol name) (STy t) + deriving (Show) + +instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where + a + b = NEOp (OAdd knownScalTy) (NEPair a b) + a * b = NEOp (OMul knownScalTy) (NEPair a b) + negate e = NEOp (ONeg knownScalTy) e + abs = error "abs undefined for NExpr" + signum = error "signum undefined for NExpr" + fromInteger = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty . fromInteger + +instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st)) + => Fractional (NExpr env t) where + recip e = NEOp (ORecip knownScalTy) e + fromRational = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty . fromRational + +instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st)) + => Floating (NExpr env t) where + pi = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty pi + exp = NEOp (OExp knownScalTy) + log = NEOp (OExp knownScalTy) + sin = undefined ; cos = undefined ; tan = undefined + asin = undefined ; acos = undefined ; atan = undefined + sinh = undefined ; cosh = undefined + asinh = undefined ; acosh = undefined ; atanh = undefined + +instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where + fromLabel = Var symbolSing knownTy + +instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where + fromLabel = NEVar (fromLabel @name) + +-- | Innermost variable variable on the outside, on the right. +data NEnv env where + NTop :: NEnv '[] + NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env) + +-- | First (outermost) parameter on the outside, on the left. +-- * env: environment of this function (grows as you go deeper inside lambdas) +-- * env': environment of the body of the function +-- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list +data NFun env env' t where + NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t + NBody :: NExpr env' t -> NFun env' env' t + +type family UnName env where + UnName '[] = '[] + UnName ('(name, t) : env) = t : UnName env + +envFromNEnv :: NEnv env -> SList STy (UnName env) +envFromNEnv NTop = SNil +envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env + +inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t +inlineNFun fun args = NEUnnamed (fromNamed fun) args + +fromNamed :: NFun '[] env t -> Ex (UnName env) t +fromNamed = fromNamedFun NTop + +-- | Some of the parameters have already been put in the environment; some +-- haven't. Transfer all parameters to the left into the environment. +-- +-- [] `fromNamedFun` λx y z. E +-- = []:x `fromNamedFun` λy z. E +-- = []:x:y `fromNamedFun` λz. E +-- = []:x:y:z `fromNamedFun` λ. E +-- = []:x:y:z `fromNamedExpr` E +fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t +fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun +fromNamedFun env (NBody e) = fromNamedExpr env e + +fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t +fromNamedExpr val = \case + NEVar var@(Var _ ty) + | Just idx <- find var val -> EVar ext ty idx + | otherwise -> error "Variable out of scope in conversion from surface \ + \expression to De Bruijn expression" + NELet n a b -> ELet ext (go a) (lambda val n b) + + NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e) + + NEPair a b -> EPair ext (go a) (go b) + NEFst e -> EFst ext (go e) + NESnd e -> ESnd ext (go e) + NENil -> ENil ext + NEInl t e -> EInl ext t (go e) + NEInr t e -> EInr ext t (go e) + NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b) + NENothing t -> ENothing ext t + NEJust e -> EJust ext (go e) + NEMaybe a n b c -> EMaybe ext (go a) (lambda val n b) (go c) + + NEConstArr n t x -> EConstArr ext n t x + NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) + NEMap n a b -> EMap ext (lambda val n a) (go b) + NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c) + NESum1Inner e -> ESum1Inner ext (go e) + NEUnit e -> EUnit ext (go e) + NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) + NEMaximum1Inner e -> EMaximum1Inner ext (go e) + NEMinimum1Inner e -> EMinimum1Inner ext (go e) + NEReshape n a b -> EReshape ext n (go a) (go b) + + NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c) + NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + + NEConst t x -> EConst ext t x + NEIdx0 e -> EIdx0 ext (go e) + NEIdx1 a b -> EIdx1 ext (go a) (go b) + NEIdx a b -> EIdx ext (go a) (go b) + NEShape e -> EShape ext (go e) + NEOp op e -> EOp ext op (go e) + + NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 -> + ECustom ext ta tb ttape + (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a) + (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) + (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) + (go e1) (go e2) + NERecompute e -> ERecompute ext (go e) + + NEWith t a n b -> EWith ext t (go a) (lambda val n b) + NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c) + + NEError t s -> EError ext t s + + NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args + where + go :: NExpr env t' -> Ex (UnName env) t' + go = fromNamedExpr val + + find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t') + find _ NTop = Nothing + find var@(Var s ty) (val' `NPush` Var s' ty') + | Just Refl <- testEquality s s' + , Just Refl <- testEquality ty ty' + = Just IZ + | otherwise + = IS <$> find var val' + + lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b + lambda val' var e = fromNamedExpr (val' `NPush` var) e + + lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c + lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e + + injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t + injectWrapLet e SNil = e + injectWrapLet e (arg `SCons` args) = + injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e) + args + +dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env) +dropNth SZ (val `NPush` _) = val +dropNth (SS i) (val `NPush` p) = dropNth i val `NPush` p +dropNth _ NTop = error "DropNth: index out of range" + +dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env +dropNthW SZ (_ `NPush` _) = WSink +dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) +dropNthW _ NTop = error "DropNth: index out of range" + +assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r +assertSymbolNotUnderscore s@SSymbol k = + case symbolVal s of + "_" -> error "assertSymbolNotUnderscore: was underscore" + _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k + +assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r +assertSymbolDistinct s1@SSymbol s2@SSymbol k + | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")" + | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k + +equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r +equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k -- cgit v1.2.3-70-g09d2