summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-28 23:57:31 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-28 23:57:31 +0100
commit9eec3fb3ec727e61a34742be7672a4e281127576 (patch)
treecdfc2d9225077e082e18f1d1a00ea9e3ec2deca4 /src
parentb3b7cebfac9d9c54a2e51152e60e04999a7683e3 (diff)
test: Simplify and make it a bit faster
Diffstat (limited to 'src')
-rw-r--r--src/CHAD/Types/ToTan.hs42
-rw-r--r--src/Interpreter.hs2
-rw-r--r--src/Simplify.hs2
3 files changed, 44 insertions, 2 deletions
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
new file mode 100644
index 0000000..a75fdb8
--- /dev/null
+++ b/src/CHAD/Types/ToTan.hs
@@ -0,0 +1,42 @@
+{-# LANGUAGE GADTs #-}
+module CHAD.Types.ToTan where
+
+import Data.Bifunctor (bimap)
+
+import Array
+import AST.Types
+import CHAD.Types
+import Data
+import ForwardAD
+import Interpreter.Rep
+
+
+toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
+toTanE SNil SNil SNil = SNil
+toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
+ Value (toTan t p x) `SCons` toTanE env primal inp
+
+toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
+toTan typ primal der = case typ of
+ STNil -> der
+ STPair t1 t2 -> case der of
+ Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
+ Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
+ STEither t1 t2 -> case der of
+ Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
+ Just d -> case (primal, d) of
+ (Left p, Left d') -> Left (toTan t1 p d')
+ (Right p, Right d') -> Right (toTan t2 p d')
+ _ -> error "Primal and cotangent disagree on Either alternative"
+ STMaybe t -> liftA2 (toTan t) primal der
+ STArr _ t
+ | shapeSize (arrayShape der) == 0 ->
+ arrayMap (zeroTan t) primal
+ | arrayShape primal == arrayShape der ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
+ STScal sty -> case sty of
+ STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
+ STAccum{} -> error "Accumulators not allowed in input program"
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index deb829b..dd558fe 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -558,7 +558,7 @@ tupRepIdx :: (forall m. f (S m) -> (f m, Int))
tupRepIdx _ SZ _ = ()
tupRepIdx uncons (SS n) tup =
let (tup', i) = uncons tup
- in (tupRepIdx uncons n tup', fromIntegral @Int @Int64 i)
+ in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons (IxCons idx i) = (idx, i)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 785e2bd..673b58c 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -7,7 +7,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-module Simplify where
+module Simplify (simplifyN, simplifyFix) where
import Data.Function (fix)
import Data.Monoid (Any(..))