diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Drev/Types | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Drev/Types')
| -rw-r--r-- | src/CHAD/Drev/Types/ToTan.hs | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/src/CHAD/Drev/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs new file mode 100644 index 0000000..019119c --- /dev/null +++ b/src/CHAD/Drev/Types/ToTan.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE GADTs #-} +module CHAD.Drev.Types.ToTan where + +import Data.Bifunctor (bimap) + +import CHAD.Array +import CHAD.AST.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.ForwardAD +import CHAD.Interpreter.Rep + + +toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) +toTanE SNil SNil SNil = SNil +toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = + Value (toTan t p x) `SCons` toTanE env primal inp + +toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) +toTan typ primal der = case typ of + STNil -> der + STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal + STEither t1 t2 -> case der of + Nothing -> bimap (zeroTan t1) (zeroTan t2) primal + Just d -> case (primal, d) of + (Left p, Left d') -> Left (toTan t1 p d') + (Right p, Right d') -> Right (toTan t2 p d') + _ -> error "Primal and cotangent disagree on Either alternative" + STLEither t1 t2 -> case (primal, der) of + (_, Nothing) -> Nothing + (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) + (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) + _ -> error "Primal and cotangent disagree on LEither alternative" + STMaybe t -> liftA2 (toTan t) primal der + STArr _ t + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" + STScal sty -> case sty of + STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der + STAccum{} -> error "Accumulators not allowed in input program" |
