diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-09 22:59:12 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-09 22:59:12 +0100 |
commit | d4d4473ee229674f73929c0860a7e29302330361 (patch) | |
tree | 65efd81ff635bad0e85cc518dc67c7c204313c73 | |
parent | 34887168c0e2deb549e0e7c77e837ab269d894a2 (diff) |
Cleanup, more Language operations
-rw-r--r-- | src/AST.hs | 6 | ||||
-rw-r--r-- | src/Example.hs | 2 | ||||
-rw-r--r-- | src/Language.hs | 43 | ||||
-rw-r--r-- | src/Language/AST.hs | 21 |
4 files changed, 61 insertions, 11 deletions
@@ -104,9 +104,9 @@ data Expr x env t where -- ECustom does not allow a derivative to be generated for 'a', and hence -- none is propagated. ECustom :: x t -> STy a -> STy b -> STy tape - -> Expr x '[b, a] t -- ^ regular operation - -> Expr x '[D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass - -> Expr x '[D2 t, tape] (D2 b) -- ^ CHAD reverse derivative + -> Expr x [b, a] t -- ^ regular operation + -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative -> Expr x env a -> Expr x env b -> Expr x env t diff --git a/src/Example.hs b/src/Example.hs index 1775bb9..697c4d9 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -109,7 +109,7 @@ ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) ex6 = fromNamed $ lambda #x $ lambda #n $ body $ let_ #a (unit #x) $ let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $ - idx0 (#b .! 3) + #b ! pair nil 3 type R = TScal TF64 diff --git a/src/Language.hs b/src/Language.hs index aa55140..7aceee7 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -33,6 +33,10 @@ lambda = NLam inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t inline = inlineNFun +-- To be used to construct the argument list for 'inline'. +-- +-- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y +-- > in inline fun (SNil .$ 16 .$ 26) (.$) :: SList f list -> f a -> SList f (a : list) (.$) = flip SCons @@ -52,11 +56,11 @@ snd_ = NESnd nil :: NExpr env TNil nil = NENil -inl :: STy b -> NExpr env a -> NExpr env (TEither a b) -inl = NEInl +inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b) +inl = NEInl knownTy -inr :: STy a -> NExpr env b -> NExpr env (TEither a b) -inr = NEInr +inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b) +inr = NEInr knownTy case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 @@ -84,6 +88,16 @@ build2 a1 a2 (v1 :-> v2 :-> b) = build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) build n a (v :-> b) = NEBuild n a v b +map_ :: forall n a b env name. (KnownNat n, KnownTy a) + => (Var name a :-> NExpr ('(name, a) : env) b) + -> NExpr env (TArr n a) -> NExpr env (TArr n b) +map_ (v :-> a) b + | Dict <- styKnown (tTup (sreplicate (knownNat @n) tIx)) = + let_ #arg b $ + build knownNat (shape #arg) $ #i :-> + let_ v (#arg ! #i) $ + NEDrop (SS SZ) (NEDrop (SS SZ) a) + fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3 @@ -111,9 +125,9 @@ const_ x = idx0 :: NExpr env (TArr Z t) -> NExpr env t idx0 = NEIdx0 -(.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) -(.!) = NEIdx1 -infixl 9 .! +-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) +-- (.!) = NEIdx1 +-- infixl 9 .! (!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t (!) = NEIdx @@ -125,6 +139,9 @@ shape = NEShape oper :: SOp a t -> NExpr env a -> NExpr env t oper = NEOp +oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t +oper2 op a b = NEOp op (pair a b) + error_ :: KnownTy t => String -> NExpr env t error_ s = NEError knownTy s @@ -159,6 +176,18 @@ infix 4 .>= not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) not_ = oper ONot +and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +and_ = oper2 OAnd + +or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +or_ = oper2 OOr + -- | The first alternative is the True case; the second is the False case. if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) + +round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64) +round_ = oper ORound64 + +toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64) +toFloat_ = oper OToFl64 diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 0ed4e51..8c91d59 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -98,6 +98,27 @@ instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st in case scalRepIsShow ty of Dict -> NEConst ty . fromInteger +instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st)) + => Fractional (NExpr env t) where + recip e = NEOp (ORecip knownScalTy) e + fromRational = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty . fromRational + +instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st)) + => Floating (NExpr env t) where + pi = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty pi + exp = NEOp (OExp knownScalTy) + log = NEOp (OExp knownScalTy) + sin = undefined ; cos = undefined ; tan = undefined + asin = undefined ; acos = undefined ; atan = undefined + sinh = undefined ; cosh = undefined + asinh = undefined ; acosh = undefined ; atanh = undefined + instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where fromLabel = Var symbolSing knownTy |