aboutsummaryrefslogtreecommitdiff
path: root/src/Language
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/Language
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/Language')
-rw-r--r--src/Language/AST.hs300
1 files changed, 0 insertions, 300 deletions
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
deleted file mode 100644
index 3d6ede5..0000000
--- a/src/Language/AST.hs
+++ /dev/null
@@ -1,300 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-module Language.AST where
-
-import Data.Kind (Type)
-import Data.Type.Equality
-import GHC.OverloadedLabels
-import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal)
-
-import Array
-import AST
-import AST.Sparse.Types
-import CHAD.Types
-import Data
-
-
-type NExpr :: [(Symbol, Ty)] -> Ty -> Type
-data NExpr env t where
- -- lambda calculus
- NEVar :: Lookup name env ~ t => Var name t -> NExpr env t
- NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t
-
- -- environment management
- NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t
-
- -- base types
- NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
- NEFst :: NExpr env (TPair a b) -> NExpr env a
- NESnd :: NExpr env (TPair a b) -> NExpr env b
- NENil :: NExpr env TNil
- NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b)
- NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b)
- NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c
- NENothing :: STy t -> NExpr env (TMaybe t)
- NEJust :: NExpr env t -> NExpr env (TMaybe t)
- NEMaybe :: NExpr env b -> Var name t -> NExpr ('(name, t) : env) b -> NExpr env (TMaybe t) -> NExpr env b
-
- -- array operations
- NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
- NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
- NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t)
- NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
- NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
- NEUnit :: NExpr env t -> NExpr env (TArr Z t)
- NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
- NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
- NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
- NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
-
- NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b)
- -> NExpr env t1
- -> NExpr env (TArr (S n) t1)
- -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
- NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2)
- -> NExpr env (TArr (S n) b)
- -> NExpr env (TArr n t2)
- -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
-
- -- expression operations
- NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
- NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
- NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
- NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
- NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
- NEOp :: SOp a t -> NExpr env a -> NExpr env t
-
- -- custom derivatives
- NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t -- ^ regular operation
- -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape) -- ^ CHAD forward pass
- -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b) -- ^ CHAD reverse derivative
- -> NExpr env a -> NExpr env b
- -> NExpr env t
-
- -- fake halfway checkpointing
- NERecompute :: NExpr env t -> NExpr env t
-
- -- accumulation effect on monoids
- NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)
- NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
-
- -- partiality
- NEError :: STy a -> String -> NExpr env a
-
- -- embedded unnamed expressions
- NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
-deriving instance Show (NExpr env t)
-
-type Lookup name env = Lookup1 (name == "_") name env
-type family Lookup1 eqblank name env where
- Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'")
- Lookup1 False name env = Lookup2 name env
-type family Lookup2 name env where
- Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
- Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env
-type family Lookup3 eq t name env where
- Lookup3 True t _ _ = t
- Lookup3 False _ name env = Lookup2 name env
-
-type family DropNth i env where
- DropNth Z (_ : env) = env
- DropNth (S i) (p : env) = p : DropNth i env
-
-data Var name t = Var (SSymbol name) (STy t)
- deriving (Show)
-
-instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where
- a + b = NEOp (OAdd knownScalTy) (NEPair a b)
- a * b = NEOp (OMul knownScalTy) (NEPair a b)
- negate e = NEOp (ONeg knownScalTy) e
- abs = error "abs undefined for NExpr"
- signum = error "signum undefined for NExpr"
- fromInteger =
- let ty = knownScalTy
- in case scalRepIsShow ty of
- Dict -> NEConst ty . fromInteger
-
-instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st))
- => Fractional (NExpr env t) where
- recip e = NEOp (ORecip knownScalTy) e
- fromRational =
- let ty = knownScalTy
- in case scalRepIsShow ty of
- Dict -> NEConst ty . fromRational
-
-instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st))
- => Floating (NExpr env t) where
- pi =
- let ty = knownScalTy
- in case scalRepIsShow ty of
- Dict -> NEConst ty pi
- exp = NEOp (OExp knownScalTy)
- log = NEOp (OExp knownScalTy)
- sin = undefined ; cos = undefined ; tan = undefined
- asin = undefined ; acos = undefined ; atan = undefined
- sinh = undefined ; cosh = undefined
- asinh = undefined ; acosh = undefined ; atanh = undefined
-
-instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where
- fromLabel = Var symbolSing knownTy
-
-instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where
- fromLabel = NEVar (fromLabel @name)
-
--- | Innermost variable variable on the outside, on the right.
-data NEnv env where
- NTop :: NEnv '[]
- NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env)
-
--- | First (outermost) parameter on the outside, on the left.
--- * env: environment of this function (grows as you go deeper inside lambdas)
--- * env': environment of the body of the function
--- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list
-data NFun env env' t where
- NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
- NBody :: NExpr env' t -> NFun env' env' t
-
-type family UnName env where
- UnName '[] = '[]
- UnName ('(name, t) : env) = t : UnName env
-
-envFromNEnv :: NEnv env -> SList STy (UnName env)
-envFromNEnv NTop = SNil
-envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env
-
-inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t
-inlineNFun fun args = NEUnnamed (fromNamed fun) args
-
-fromNamed :: NFun '[] env t -> Ex (UnName env) t
-fromNamed = fromNamedFun NTop
-
--- | Some of the parameters have already been put in the environment; some
--- haven't. Transfer all parameters to the left into the environment.
---
--- [] `fromNamedFun` λx y z. E
--- = []:x `fromNamedFun` λy z. E
--- = []:x:y `fromNamedFun` λz. E
--- = []:x:y:z `fromNamedFun` λ. E
--- = []:x:y:z `fromNamedExpr` E
-fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t
-fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun
-fromNamedFun env (NBody e) = fromNamedExpr env e
-
-fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t
-fromNamedExpr val = \case
- NEVar var@(Var _ ty)
- | Just idx <- find var val -> EVar ext ty idx
- | otherwise -> error "Variable out of scope in conversion from surface \
- \expression to De Bruijn expression"
- NELet n a b -> ELet ext (go a) (lambda val n b)
-
- NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e)
-
- NEPair a b -> EPair ext (go a) (go b)
- NEFst e -> EFst ext (go e)
- NESnd e -> ESnd ext (go e)
- NENil -> ENil ext
- NEInl t e -> EInl ext t (go e)
- NEInr t e -> EInr ext t (go e)
- NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b)
- NENothing t -> ENothing ext t
- NEJust e -> EJust ext (go e)
- NEMaybe a n b c -> EMaybe ext (go a) (lambda val n b) (go c)
-
- NEConstArr n t x -> EConstArr ext n t x
- NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
- NEMap n a b -> EMap ext (lambda val n a) (go b)
- NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c)
- NESum1Inner e -> ESum1Inner ext (go e)
- NEUnit e -> EUnit ext (go e)
- NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
- NEMaximum1Inner e -> EMaximum1Inner ext (go e)
- NEMinimum1Inner e -> EMinimum1Inner ext (go e)
- NEReshape n a b -> EReshape ext n (go a) (go b)
-
- NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c)
- NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
-
- NEConst t x -> EConst ext t x
- NEIdx0 e -> EIdx0 ext (go e)
- NEIdx1 a b -> EIdx1 ext (go a) (go b)
- NEIdx a b -> EIdx ext (go a) (go b)
- NEShape e -> EShape ext (go e)
- NEOp op e -> EOp ext op (go e)
-
- NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 ->
- ECustom ext ta tb ttape
- (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a)
- (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
- (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
- (go e1) (go e2)
- NERecompute e -> ERecompute ext (go e)
-
- NEWith t a n b -> EWith ext t (go a) (lambda val n b)
- NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c)
-
- NEError t s -> EError ext t s
-
- NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args
- where
- go :: NExpr env t' -> Ex (UnName env) t'
- go = fromNamedExpr val
-
- find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t')
- find _ NTop = Nothing
- find var@(Var s ty) (val' `NPush` Var s' ty')
- | Just Refl <- testEquality s s'
- , Just Refl <- testEquality ty ty'
- = Just IZ
- | otherwise
- = IS <$> find var val'
-
- lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b
- lambda val' var e = fromNamedExpr (val' `NPush` var) e
-
- lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c
- lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e
-
- injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t
- injectWrapLet e SNil = e
- injectWrapLet e (arg `SCons` args) =
- injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e)
- args
-
-dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env)
-dropNth SZ (val `NPush` _) = val
-dropNth (SS i) (val `NPush` p) = dropNth i val `NPush` p
-dropNth _ NTop = error "DropNth: index out of range"
-
-dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
-dropNthW SZ (_ `NPush` _) = WSink
-dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val)
-dropNthW _ NTop = error "DropNth: index out of range"
-
-assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r
-assertSymbolNotUnderscore s@SSymbol k =
- case symbolVal s of
- "_" -> error "assertSymbolNotUnderscore: was underscore"
- _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k
-
-assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r
-assertSymbolDistinct s1@SSymbol s2@SSymbol k
- | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")"
- | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k
-
-equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r
-equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k