summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-10-07 14:34:27 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-10-07 14:34:27 +0200
commit72eddb67bb6f048fc2076184be3a32169026a832 (patch)
tree2f5ca7511a798d7329b12d499f4dea7239b39c50 /src
parent948cae3ca7279040627db393e4372a668f8a22f7 (diff)
Towards a test suite
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs7
-rw-r--r--src/CHAD.hs5
-rw-r--r--src/CHAD/Types.hs5
-rw-r--r--src/ForwardAD.hs15
4 files changed, 28 insertions, 4 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 8dfea68..5dab62f 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -129,6 +129,13 @@ tTup = mkTup STNil STPair
eTup :: SList (Ex env) list -> Ex env (Tup list)
eTup = mkTup (ENil ext) (EPair ext)
+unTup :: (forall a b. c (TPair a b) -> (c a, c b))
+ -> SList f list -> c (Tup list) -> SList c list
+unTup _ SNil _ = SNil
+unTup unpack (_ `SCons` list) tup =
+ let (xs, x) = unpack tup
+ in x `SCons` unTup unpack list xs
+
type family InvTup core env where
InvTup core '[] = core
InvTup core (t : ts) = InvTup (TPair core t) ts
diff --git a/src/CHAD.hs b/src/CHAD.hs
index bcc1485..55d94b1 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -25,6 +25,7 @@ module CHAD (
freezeRet,
Storage(..),
Descr(..),
+ Select,
) where
import Data.Bifunctor (first, second)
@@ -630,10 +631,6 @@ sD1eEnv :: Descr env sto -> SList STy (D1E env)
sD1eEnv DTop = SNil
sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
-d2e :: SList STy env -> SList STy (D2E env)
-d2e SNil = SNil
-d2e (SCons t ts) = SCons (d2 t) (d2e ts)
-
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace SNil = SNil
d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 0b32393..0b73a3a 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -4,6 +4,7 @@
module CHAD.Types where
import AST.Types
+import Data
type family D1 t where
@@ -63,3 +64,7 @@ d2 (STScal t) = case t of
STF64 -> STScal STF64
STBool -> STNil
d2 STAccum{} = error "Accumulators not allowed in input program"
+
+d2e :: SList STy env -> SList STy (D2E env)
+d2e SNil = SNil
+d2e (SCons t ts) = SCons (d2 t) (d2e ts)
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index 63244a8..6d53b48 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -52,6 +52,21 @@ tanty (STScal t) = case t of
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
+zeroTan :: STy t -> Rep t -> Rep (Tan t)
+zeroTan STNil () = ()
+zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)
+zeroTan (STEither a _) (Left x) = Left (zeroTan a x)
+zeroTan (STEither _ b) (Right y) = Right (zeroTan b y)
+zeroTan (STMaybe _) Nothing = Nothing
+zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)
+zeroTan (STArr _ t) x = fmap (zeroTan t) x
+zeroTan (STScal STI32) _ = ()
+zeroTan (STScal STI64) _ = ()
+zeroTan (STScal STF32) _ = 0.0
+zeroTan (STScal STF64) _ = 0.0
+zeroTan (STScal STBool) _ = ()
+zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+
unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
unzipDN STNil _ = ((), ())
unzipDN (STPair a b) (d1, d2) =