diff options
Diffstat (limited to 'src/CHAD/Language.hs')
| -rw-r--r-- | src/CHAD/Language.hs | 423 |
1 files changed, 423 insertions, 0 deletions
diff --git a/src/CHAD/Language.hs b/src/CHAD/Language.hs new file mode 100644 index 0000000..6621eef --- /dev/null +++ b/src/CHAD/Language.hs @@ -0,0 +1,423 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} +module CHAD.Language ( + -- * Named expressions + fromNamed, + NExpr, NFun, + + -- * Functions + lambda, + body, + inline, + (.$), + + -- * Basic language constructs + let_, + pair, fst_, snd_, nil, + inl, inr, case_, + nothing, just, maybe_, + + -- * Array operations + constArr_, + build1, build2, build, + map_, + fold1i, fold1i', + sum1i, + unit, + replicate1i, + maximum1i, minimum1i, + reshape, + fold1iD1, fold1iD1', + fold1iD2, + + -- * Scalar operations + -- | Note that 'NExpr' is also an instance of some numeric classes like 'Num' and 'Floating'. + const_, + idx0, + (!), + shape, + length_, + error_, + (.==), (.<), (CHAD.Language..>), (.<=), (.>=), + not_, and_, or_, + mod_, round_, toFloat_, idiv, + + -- * Control flow + if_, + + -- * Special operations + custom, + recompute, + with, accum, accumS, + oper, oper2, + + -- * Helper types + (:->)(..), + + -- * Reexports + TIx, + Lookup, + Ex, + Ty(..), + SNat(..), Nat(..), N0, N1, N2, N3, +) where + +import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.Language.AST + + +-- | Helper type, used for e.g. 'case_' and 'build'. +data a :-> b = a :-> b + deriving (Show) +infixr 0 :-> + + +-- | See 'fromNamed' for a usage example. +body :: NExpr env t -> NFun env env t +body = NBody + +-- | See 'fromNamed' for a usage example. +lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t +lambda = NLam + +-- | Inline a function here, with the given list of expressions as arguments. +-- While this is a normal 'SList', the @params@ list is reversed from the +-- natural argument order of the function; the '(.$)' helper operator serves to +-- "fix" the order. +-- +-- @ +-- let fun = 'lambda' \@(TScal TF64) #x $ 'lambda' \@(TScal TBool) #b $ 'body' $ if_ #b #x (#x + 1) +-- in 'inline' fun ('SNil' .$ 16 .$ 'const_' True) +-- @ +-- +-- Note that no 'const_' is needed for the @16@, because 'NExpr' implements +-- 'Num'. +inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t +inline = inlineNFun + +-- | Helper for constructing the argument list for 'inline'; +-- @(.$) = flip 'SCons'@. See 'inline'. +(.$) :: SList f list -> f a -> SList f (a : list) +(.$) = flip SCons + + +-- | The first 'Var' argument is the left-hand side of this let-binding. For example: +-- +-- @ +-- 'fromNamed' $ 'lambda' \@(TScal TI64) #a $ 'body' $ +-- 'let_' #x (#a + 1) $ +-- #x * #a +-- @ +-- +-- This produces an expression of type @'Ex' '[TScal TI64] (TScal TI64)@ that +-- corresponds to the Haskell code @\\a -> let x = a + 1 in x * a@. +let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t +let_ = NELet + +pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) +pair = NEPair + +fst_ :: NExpr env (TPair a b) -> NExpr env a +fst_ = NEFst + +snd_ :: NExpr env (TPair a b) -> NExpr env b +snd_ = NESnd + +nil :: NExpr env TNil +nil = NENil + +inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b) +inl = NEInl knownTy + +inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b) +inr = NEInr knownTy + +-- | A @case@ expression on @Either@s. For example, the following expression +-- will evaluate to 10 + 1 = 11: +-- +-- @ +-- 'case_' ('inl' 10) +-- (#x :-> #x + 1) +-- (#y :-> #y * 2) +-- @ +case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c +case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 + +nothing :: KnownTy a => NExpr env (TMaybe a) +nothing = NENothing knownTy + +just :: NExpr env a -> NExpr env (TMaybe a) +just = NEJust + +-- | Analogue of the 'Prelude.maybe' function in the Haskell Prelude: +-- +-- @ +-- 'maybe_' 2 (#x :-> #x * 3) (...) +-- @ +-- +-- will return 2 if @(...)@ is @Nothing@ and @x + 3@ if it is @Just x@. +maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b +maybe_ a (v :-> b) c = NEMaybe a v b c + +-- | To construct 'Array' values, see "CHAD.Array". +constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) +constArr_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConstArr knownNat ty x + +-- | Special case of 'build' for 1-dimensional arrays. This produces the array +-- [0.0, 1.0, 2.0]: +-- +-- @ +-- 'build1' 3 (#i :-> 'toFloat_' #i) +-- @ +build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) +build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b)) + +-- | Special case of 'build' for 2-dimensional arrays. +build2 :: NExpr env TIx -> NExpr env TIx + -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t) + -> NExpr env (TArr (S (S Z)) t) +build2 a1 a2 (v1 :-> v2 :-> b) = + NEBuild (SS (SS SZ)) + (pair (pair nil a1) a2) + #idx + (let_ v1 (snd_ (fst_ #idx)) $ + let_ v2 (NEDrop SZ (snd_ #idx)) $ + NEDrop (SS (SS SZ)) b) + +-- | General n-dimensional elementwise array constructor. A 3-dimensional index +-- looks like @((((), i1), i2), i3)@; other dimensionalities are analogous. The +-- innermost dimension (i.e. whose index variable varies the fastest in the +-- standard memory layout) is the right-most index, i.e. @i3@ in 3D example. To +-- create a 10-by-10 table of (row, column) pairs: +-- +-- @ +-- 'build' ('SS' ('SS' 'SZ')) ('pair' ('pair' 'nil' 10) 10) (#i :-> #j :-> 'pair' #i #j) +-- @ +build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) +build n a (v :-> b) = NEBuild n a v b + +map_ :: forall n a b env name. (KnownNat n, KnownTy a) + => (Var name a :-> NExpr ('(name, a) : env) b) + -> NExpr env (TArr n a) -> NExpr env (TArr n b) +map_ (v :-> a) b = NEMap v a b + +-- | Fold over the innermost dimension of an array, thus reducing its dimensionality by one. +fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t t) + in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +-- | The underlying AST constructor for a fold takes a function with /one/ +-- argument: a pair of inputs. 'fold1i'' directly returns this AST constructor +-- in case it is helpful for testing. The 'fold1i' function is a convenience +-- wrapper around 'fold1i''. +fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3 + +sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +sum1i e = NESum1Inner e + +unit :: NExpr env t -> NExpr env (TArr Z t) +unit = NEUnit + +replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t)) +replicate1i n a = NEReplicate1Inner n a + +maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +maximum1i e = NEMaximum1Inner e + +minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +minimum1i e = NEMinimum1Inner e + +reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) +reshape = NEReshape + +-- | 'fold1iD1'' with a curried combination function. +fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t1 t1) + in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +-- | Primal of a fold. Not supported in the input program for reverse differentiation. +fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3 + +-- | Reverse pass of a fold. Not supported in the input program for reverse differentiation. +fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) + -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) +fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3 + +const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) +const_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty x + +idx0 :: NExpr env (TArr Z t) -> NExpr env t +idx0 = NEIdx0 + +-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) +-- (.!) = NEIdx1 +-- infixl 9 .! + +-- | Index an array. Note that the index is a tuple, just like the argument to +-- the function in 'build'. To index a 2-dimensional array @a@ at row @i@ and +-- column @j@, write @a '!' 'pair' ('pair' 'nil' i) j@. +(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t +(!) = NEIdx +infixl 9 ! + +shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) +shape = NEShape + +-- | Convenience special case of 'shape' for single-dimensional arrays. +length_ :: NExpr env (TArr N1 t) -> NExpr env TIx +length_ e = snd_ (shape e) + +oper :: SOp a t -> NExpr env a -> NExpr env t +oper = NEOp + +oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t +oper2 op a b = NEOp op (pair a b) + +error_ :: KnownTy t => String -> NExpr env t +error_ s = NEError knownTy s + +-- | Specify a custom reverse derivative for a subexpression. Morally, the type +-- of this combinator should be read as follows: +-- +-- @ +-- custom :: (a -> b -> t) -- normal semantics +-- -> (D1 a -> D1 b -> (D1 t, tape)) -- forward pass +-- -> (tape -> D2 t -> D2 b) -- reverse pass +-- -> a -> b -- arguments +-- -> t -- result +-- @ +-- +-- In normal evaluation, or when forward-differentiating, the first argument is +-- taken and the second and third are ignored. When reverse-differentiating +-- using CHAD, however, the /first/ argument is ignored and the second and +-- third arguments are respectively put in the forward and the reverse passes +-- of the derivative program. The @tape@ value may be used to remember primals +-- for the reverse pass. +-- +-- This combinator allows for "inactive" and "active" inputs to the operation; +-- derivatives to the "inactive" input are not propagated. The active input +-- (whose derivatives /are/ propagated) has type @b@; the inactive input has +-- type @a@. +-- +-- No accumulators are allowed inside @a@, @b@ and @tape@. +custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) + -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)) + -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)) + -> NExpr env a -> NExpr env b + -> NExpr env t +custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = + NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 + +-- | Semantically the identity, but when reverse differentiating using CHAD, +-- the contained expression is recomputed in the reverse pass. This is a +-- light-weight form of checkpointing, with the goal of reducing the number +-- primal values being stored and thus reducing memory use and memory traffic. +-- +-- Note that free variables of the contained expression do still need to be +-- stored, as we do need to be able to recompute the expression in the reverse +-- pass. +recompute :: NExpr env a -> NExpr env a +recompute = NERecompute + +-- | Introduce an accumulator. The initial value is not allowed to be sparse! +-- See 'CHAD.AST.EWith'. Not supported in the input program for reverse +-- differentiation. +with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) +with a (n :-> b) = NEWith (knownMTy @t) a n b + +-- | Accumulate to an accumulator. Not supported in the input program for +-- reverse differentiation. +accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil +accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c + +-- | Accumulate to an accumulator with additional sparsity. Not supported in +-- the input program for reverse differentiation. +accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil +accumS p a sp b c = NEAccum knownMTy p a sp b c + + +(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .== b = oper (OEq knownScalTy) (pair a b) +infix 4 .== + +(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .< b = oper (OLt knownScalTy) (pair a b) +infix 4 .< + +(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>) = flip (.<) +infix 4 .> + +(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .<= b = oper (OLe knownScalTy) (pair a b) +infix 4 .<= + +(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>=) = flip (.<=) +infix 4 .>= + +not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) +not_ = oper ONot + +and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +and_ = oper2 OAnd +infixr 3 `and_` + +or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +or_ = oper2 OOr +infixr 2 `or_` + +mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) +mod_ = oper2 (OMod knownScalTy) +infixl 7 `mod_` + +-- | The first alternative is the True case; the second is the False case. +if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t +if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) + +round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64) +round_ = oper ORound64 + +toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64) +toFloat_ = oper OToFl64 + +idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t) +idiv = oper2 (OIDiv knownScalTy) +infixl 7 `idiv` |
