summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-09 22:59:12 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-09 22:59:12 +0100
commitd4d4473ee229674f73929c0860a7e29302330361 (patch)
tree65efd81ff635bad0e85cc518dc67c7c204313c73
parent34887168c0e2deb549e0e7c77e837ab269d894a2 (diff)
Cleanup, more Language operations
-rw-r--r--src/AST.hs6
-rw-r--r--src/Example.hs2
-rw-r--r--src/Language.hs43
-rw-r--r--src/Language/AST.hs21
4 files changed, 61 insertions, 11 deletions
diff --git a/src/AST.hs b/src/AST.hs
index e3da634..263b806 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -104,9 +104,9 @@ data Expr x env t where
-- ECustom does not allow a derivative to be generated for 'a', and hence
-- none is propagated.
ECustom :: x t -> STy a -> STy b -> STy tape
- -> Expr x '[b, a] t -- ^ regular operation
- -> Expr x '[D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
- -> Expr x '[D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
+ -> Expr x [b, a] t -- ^ regular operation
+ -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
+ -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
-> Expr x env a -> Expr x env b
-> Expr x env t
diff --git a/src/Example.hs b/src/Example.hs
index 1775bb9..697c4d9 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -109,7 +109,7 @@ ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 = fromNamed $ lambda #x $ lambda #n $ body $
let_ #a (unit #x) $
let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
- idx0 (#b .! 3)
+ #b ! pair nil 3
type R = TScal TF64
diff --git a/src/Language.hs b/src/Language.hs
index aa55140..7aceee7 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -33,6 +33,10 @@ lambda = NLam
inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t
inline = inlineNFun
+-- To be used to construct the argument list for 'inline'.
+--
+-- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y
+-- > in inline fun (SNil .$ 16 .$ 26)
(.$) :: SList f list -> f a -> SList f (a : list)
(.$) = flip SCons
@@ -52,11 +56,11 @@ snd_ = NESnd
nil :: NExpr env TNil
nil = NENil
-inl :: STy b -> NExpr env a -> NExpr env (TEither a b)
-inl = NEInl
+inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b)
+inl = NEInl knownTy
-inr :: STy a -> NExpr env b -> NExpr env (TEither a b)
-inr = NEInr
+inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b)
+inr = NEInr knownTy
case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c
case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2
@@ -84,6 +88,16 @@ build2 a1 a2 (v1 :-> v2 :-> b) =
build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t)
build n a (v :-> b) = NEBuild n a v b
+map_ :: forall n a b env name. (KnownNat n, KnownTy a)
+ => (Var name a :-> NExpr ('(name, a) : env) b)
+ -> NExpr env (TArr n a) -> NExpr env (TArr n b)
+map_ (v :-> a) b
+ | Dict <- styKnown (tTup (sreplicate (knownNat @n) tIx)) =
+ let_ #arg b $
+ build knownNat (shape #arg) $ #i :->
+ let_ v (#arg ! #i) $
+ NEDrop (SS SZ) (NEDrop (SS SZ) a)
+
fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3
@@ -111,9 +125,9 @@ const_ x =
idx0 :: NExpr env (TArr Z t) -> NExpr env t
idx0 = NEIdx0
-(.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
-(.!) = NEIdx1
-infixl 9 .!
+-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
+-- (.!) = NEIdx1
+-- infixl 9 .!
(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
(!) = NEIdx
@@ -125,6 +139,9 @@ shape = NEShape
oper :: SOp a t -> NExpr env a -> NExpr env t
oper = NEOp
+oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t
+oper2 op a b = NEOp op (pair a b)
+
error_ :: KnownTy t => String -> NExpr env t
error_ s = NEError knownTy s
@@ -159,6 +176,18 @@ infix 4 .>=
not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool)
not_ = oper ONot
+and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
+and_ = oper2 OAnd
+
+or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
+or_ = oper2 OOr
+
-- | The first alternative is the True case; the second is the False case.
if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b)
+
+round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64)
+round_ = oper ORound64
+
+toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64)
+toFloat_ = oper OToFl64
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 0ed4e51..8c91d59 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -98,6 +98,27 @@ instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st
in case scalRepIsShow ty of
Dict -> NEConst ty . fromInteger
+instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st))
+ => Fractional (NExpr env t) where
+ recip e = NEOp (ORecip knownScalTy) e
+ fromRational =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConst ty . fromRational
+
+instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st))
+ => Floating (NExpr env t) where
+ pi =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConst ty pi
+ exp = NEOp (OExp knownScalTy)
+ log = NEOp (OExp knownScalTy)
+ sin = undefined ; cos = undefined ; tan = undefined
+ asin = undefined ; acos = undefined ; atan = undefined
+ sinh = undefined ; cosh = undefined
+ asinh = undefined ; acosh = undefined ; atanh = undefined
+
instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where
fromLabel = Var symbolSing knownTy