diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-28 23:57:31 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-28 23:57:31 +0100 |
commit | 9eec3fb3ec727e61a34742be7672a4e281127576 (patch) | |
tree | cdfc2d9225077e082e18f1d1a00ea9e3ec2deca4 /src | |
parent | b3b7cebfac9d9c54a2e51152e60e04999a7683e3 (diff) |
test: Simplify and make it a bit faster
Diffstat (limited to 'src')
-rw-r--r-- | src/CHAD/Types/ToTan.hs | 42 | ||||
-rw-r--r-- | src/Interpreter.hs | 2 | ||||
-rw-r--r-- | src/Simplify.hs | 2 |
3 files changed, 44 insertions, 2 deletions
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs new file mode 100644 index 0000000..a75fdb8 --- /dev/null +++ b/src/CHAD/Types/ToTan.hs @@ -0,0 +1,42 @@ +{-# LANGUAGE GADTs #-} +module CHAD.Types.ToTan where + +import Data.Bifunctor (bimap) + +import Array +import AST.Types +import CHAD.Types +import Data +import ForwardAD +import Interpreter.Rep + + +toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) +toTanE SNil SNil SNil = SNil +toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = + Value (toTan t p x) `SCons` toTanE env primal inp + +toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) +toTan typ primal der = case typ of + STNil -> der + STPair t1 t2 -> case der of + Nothing -> bimap (zeroTan t1) (zeroTan t2) primal + Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + STEither t1 t2 -> case der of + Nothing -> bimap (zeroTan t1) (zeroTan t2) primal + Just d -> case (primal, d) of + (Left p, Left d') -> Left (toTan t1 p d') + (Right p, Right d') -> Right (toTan t2 p d') + _ -> error "Primal and cotangent disagree on Either alternative" + STMaybe t -> liftA2 (toTan t) primal der + STArr _ t + | shapeSize (arrayShape der) == 0 -> + arrayMap (zeroTan t) primal + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" + STScal sty -> case sty of + STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der + STAccum{} -> error "Accumulators not allowed in input program" diff --git a/src/Interpreter.hs b/src/Interpreter.hs index deb829b..dd558fe 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -558,7 +558,7 @@ tupRepIdx :: (forall m. f (S m) -> (f m, Int)) tupRepIdx _ SZ _ = () tupRepIdx uncons (SS n) tup = let (tup', i) = uncons tup - in (tupRepIdx uncons n tup', fromIntegral @Int @Int64 i) + in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i ixUncons :: Index (S n) -> (Index n, Int) ixUncons (IxCons idx i) = (idx, i) diff --git a/src/Simplify.hs b/src/Simplify.hs index 785e2bd..673b58c 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -7,7 +7,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module Simplify where +module Simplify (simplifyN, simplifyFix) where import Data.Function (fix) import Data.Monoid (Any(..)) |