diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/Language/AST.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/Language/AST.hs')
| -rw-r--r-- | src/Language/AST.hs | 300 |
1 files changed, 0 insertions, 300 deletions
diff --git a/src/Language/AST.hs b/src/Language/AST.hs deleted file mode 100644 index 3d6ede5..0000000 --- a/src/Language/AST.hs +++ /dev/null @@ -1,300 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -module Language.AST where - -import Data.Kind (Type) -import Data.Type.Equality -import GHC.OverloadedLabels -import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) - -import Array -import AST -import AST.Sparse.Types -import CHAD.Types -import Data - - -type NExpr :: [(Symbol, Ty)] -> Ty -> Type -data NExpr env t where - -- lambda calculus - NEVar :: Lookup name env ~ t => Var name t -> NExpr env t - NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t - - -- environment management - NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t - - -- base types - NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) - NEFst :: NExpr env (TPair a b) -> NExpr env a - NESnd :: NExpr env (TPair a b) -> NExpr env b - NENil :: NExpr env TNil - NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b) - NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b) - NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c - NENothing :: STy t -> NExpr env (TMaybe t) - NEJust :: NExpr env t -> NExpr env (TMaybe t) - NEMaybe :: NExpr env b -> Var name t -> NExpr ('(name, t) : env) b -> NExpr env (TMaybe t) -> NExpr env b - - -- array operations - NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) - NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) - NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t) - NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) - NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) - NEUnit :: NExpr env t -> NExpr env (TArr Z t) - NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) - NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) - NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) - NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) - - NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) - -> NExpr env t1 - -> NExpr env (TArr (S n) t1) - -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) - NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2) - -> NExpr env (TArr (S n) b) - -> NExpr env (TArr n t2) - -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) - - -- expression operations - NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) - NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t - NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) - NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t - NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) - NEOp :: SOp a t -> NExpr env a -> NExpr env t - - -- custom derivatives - NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t -- ^ regular operation - -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape) -- ^ CHAD forward pass - -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b) -- ^ CHAD reverse derivative - -> NExpr env a -> NExpr env b - -> NExpr env t - - -- fake halfway checkpointing - NERecompute :: NExpr env t -> NExpr env t - - -- accumulation effect on monoids - NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) - NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil - - -- partiality - NEError :: STy a -> String -> NExpr env a - - -- embedded unnamed expressions - NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t -deriving instance Show (NExpr env t) - -type Lookup name env = Lookup1 (name == "_") name env -type family Lookup1 eqblank name env where - Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'") - Lookup1 False name env = Lookup2 name env -type family Lookup2 name env where - Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") - Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env -type family Lookup3 eq t name env where - Lookup3 True t _ _ = t - Lookup3 False _ name env = Lookup2 name env - -type family DropNth i env where - DropNth Z (_ : env) = env - DropNth (S i) (p : env) = p : DropNth i env - -data Var name t = Var (SSymbol name) (STy t) - deriving (Show) - -instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where - a + b = NEOp (OAdd knownScalTy) (NEPair a b) - a * b = NEOp (OMul knownScalTy) (NEPair a b) - negate e = NEOp (ONeg knownScalTy) e - abs = error "abs undefined for NExpr" - signum = error "signum undefined for NExpr" - fromInteger = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConst ty . fromInteger - -instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st)) - => Fractional (NExpr env t) where - recip e = NEOp (ORecip knownScalTy) e - fromRational = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConst ty . fromRational - -instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st)) - => Floating (NExpr env t) where - pi = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConst ty pi - exp = NEOp (OExp knownScalTy) - log = NEOp (OExp knownScalTy) - sin = undefined ; cos = undefined ; tan = undefined - asin = undefined ; acos = undefined ; atan = undefined - sinh = undefined ; cosh = undefined - asinh = undefined ; acosh = undefined ; atanh = undefined - -instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where - fromLabel = Var symbolSing knownTy - -instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where - fromLabel = NEVar (fromLabel @name) - --- | Innermost variable variable on the outside, on the right. -data NEnv env where - NTop :: NEnv '[] - NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env) - --- | First (outermost) parameter on the outside, on the left. --- * env: environment of this function (grows as you go deeper inside lambdas) --- * env': environment of the body of the function --- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list -data NFun env env' t where - NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t - NBody :: NExpr env' t -> NFun env' env' t - -type family UnName env where - UnName '[] = '[] - UnName ('(name, t) : env) = t : UnName env - -envFromNEnv :: NEnv env -> SList STy (UnName env) -envFromNEnv NTop = SNil -envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env - -inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t -inlineNFun fun args = NEUnnamed (fromNamed fun) args - -fromNamed :: NFun '[] env t -> Ex (UnName env) t -fromNamed = fromNamedFun NTop - --- | Some of the parameters have already been put in the environment; some --- haven't. Transfer all parameters to the left into the environment. --- --- [] `fromNamedFun` λx y z. E --- = []:x `fromNamedFun` λy z. E --- = []:x:y `fromNamedFun` λz. E --- = []:x:y:z `fromNamedFun` λ. E --- = []:x:y:z `fromNamedExpr` E -fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t -fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun -fromNamedFun env (NBody e) = fromNamedExpr env e - -fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t -fromNamedExpr val = \case - NEVar var@(Var _ ty) - | Just idx <- find var val -> EVar ext ty idx - | otherwise -> error "Variable out of scope in conversion from surface \ - \expression to De Bruijn expression" - NELet n a b -> ELet ext (go a) (lambda val n b) - - NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e) - - NEPair a b -> EPair ext (go a) (go b) - NEFst e -> EFst ext (go e) - NESnd e -> ESnd ext (go e) - NENil -> ENil ext - NEInl t e -> EInl ext t (go e) - NEInr t e -> EInr ext t (go e) - NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b) - NENothing t -> ENothing ext t - NEJust e -> EJust ext (go e) - NEMaybe a n b c -> EMaybe ext (go a) (lambda val n b) (go c) - - NEConstArr n t x -> EConstArr ext n t x - NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEMap n a b -> EMap ext (lambda val n a) (go b) - NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c) - NESum1Inner e -> ESum1Inner ext (go e) - NEUnit e -> EUnit ext (go e) - NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) - NEMaximum1Inner e -> EMaximum1Inner ext (go e) - NEMinimum1Inner e -> EMinimum1Inner ext (go e) - NEReshape n a b -> EReshape ext n (go a) (go b) - - NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c) - NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) - - NEConst t x -> EConst ext t x - NEIdx0 e -> EIdx0 ext (go e) - NEIdx1 a b -> EIdx1 ext (go a) (go b) - NEIdx a b -> EIdx ext (go a) (go b) - NEShape e -> EShape ext (go e) - NEOp op e -> EOp ext op (go e) - - NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 -> - ECustom ext ta tb ttape - (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a) - (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) - (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) - (go e1) (go e2) - NERecompute e -> ERecompute ext (go e) - - NEWith t a n b -> EWith ext t (go a) (lambda val n b) - NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c) - - NEError t s -> EError ext t s - - NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args - where - go :: NExpr env t' -> Ex (UnName env) t' - go = fromNamedExpr val - - find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t') - find _ NTop = Nothing - find var@(Var s ty) (val' `NPush` Var s' ty') - | Just Refl <- testEquality s s' - , Just Refl <- testEquality ty ty' - = Just IZ - | otherwise - = IS <$> find var val' - - lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b - lambda val' var e = fromNamedExpr (val' `NPush` var) e - - lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c - lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e - - injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t - injectWrapLet e SNil = e - injectWrapLet e (arg `SCons` args) = - injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e) - args - -dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env) -dropNth SZ (val `NPush` _) = val -dropNth (SS i) (val `NPush` p) = dropNth i val `NPush` p -dropNth _ NTop = error "DropNth: index out of range" - -dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env -dropNthW SZ (_ `NPush` _) = WSink -dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) -dropNthW _ NTop = error "DropNth: index out of range" - -assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r -assertSymbolNotUnderscore s@SSymbol k = - case symbolVal s of - "_" -> error "assertSymbolNotUnderscore: was underscore" - _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k - -assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r -assertSymbolDistinct s1@SSymbol s2@SSymbol k - | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")" - | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k - -equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r -equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k |
