diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2024-10-07 14:34:27 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-10-07 14:34:27 +0200 |
commit | 72eddb67bb6f048fc2076184be3a32169026a832 (patch) | |
tree | 2f5ca7511a798d7329b12d499f4dea7239b39c50 /src | |
parent | 948cae3ca7279040627db393e4372a668f8a22f7 (diff) |
Towards a test suite
Diffstat (limited to 'src')
-rw-r--r-- | src/AST.hs | 7 | ||||
-rw-r--r-- | src/CHAD.hs | 5 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 5 | ||||
-rw-r--r-- | src/ForwardAD.hs | 15 |
4 files changed, 28 insertions, 4 deletions
@@ -129,6 +129,13 @@ tTup = mkTup STNil STPair eTup :: SList (Ex env) list -> Ex env (Tup list) eTup = mkTup (ENil ext) (EPair ext) +unTup :: (forall a b. c (TPair a b) -> (c a, c b)) + -> SList f list -> c (Tup list) -> SList c list +unTup _ SNil _ = SNil +unTup unpack (_ `SCons` list) tup = + let (xs, x) = unpack tup + in x `SCons` unTup unpack list xs + type family InvTup core env where InvTup core '[] = core InvTup core (t : ts) = InvTup (TPair core t) ts diff --git a/src/CHAD.hs b/src/CHAD.hs index bcc1485..55d94b1 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -25,6 +25,7 @@ module CHAD ( freezeRet, Storage(..), Descr(..), + Select, ) where import Data.Bifunctor (first, second) @@ -630,10 +631,6 @@ sD1eEnv :: Descr env sto -> SList STy (D1E env) sD1eEnv DTop = SNil sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) -d2e :: SList STy env -> SList STy (D2E env) -d2e SNil = SNil -d2e (SCons t ts) = SCons (d2 t) (d2e ts) - d2ace :: SList STy env -> SList STy (D2AcE env) d2ace SNil = SNil d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 0b32393..0b73a3a 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -4,6 +4,7 @@ module CHAD.Types where import AST.Types +import Data type family D1 t where @@ -63,3 +64,7 @@ d2 (STScal t) = case t of STF64 -> STScal STF64 STBool -> STNil d2 STAccum{} = error "Accumulators not allowed in input program" + +d2e :: SList STy env -> SList STy (D2E env) +d2e SNil = SNil +d2e (SCons t ts) = SCons (d2 t) (d2e ts) diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 63244a8..6d53b48 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -52,6 +52,21 @@ tanty (STScal t) = case t of STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" +zeroTan :: STy t -> Rep t -> Rep (Tan t) +zeroTan STNil () = () +zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y) +zeroTan (STEither a _) (Left x) = Left (zeroTan a x) +zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) +zeroTan (STMaybe _) Nothing = Nothing +zeroTan (STMaybe t) (Just x) = Just (zeroTan t x) +zeroTan (STArr _ t) x = fmap (zeroTan t) x +zeroTan (STScal STI32) _ = () +zeroTan (STScal STI64) _ = () +zeroTan (STScal STF32) _ = 0.0 +zeroTan (STScal STF64) _ = 0.0 +zeroTan (STScal STBool) _ = () +zeroTan STAccum{} _ = error "Accumulators not allowed in input program" + unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t)) unzipDN STNil _ = ((), ()) unzipDN (STPair a b) (d1, d2) = |