summaryrefslogtreecommitdiff
path: root/src/CHAD/Types/ToTan.hs
blob: a75fdb8738f513d1332b609cf47037614582943c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
{-# LANGUAGE GADTs #-}
module CHAD.Types.ToTan where

import Data.Bifunctor (bimap)

import Array
import AST.Types
import CHAD.Types
import Data
import ForwardAD
import Interpreter.Rep


-- | Environment-wide version of 'toTan': walks the list of types, primal
-- values and cotangent values in lockstep. Each cotangent is converted to a
-- tangent at its corresponding primal value.
toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
toTanE SNil SNil SNil = SNil
toTanE (SCons ty tys) (SCons (Value pr) prs) (SCons (Value ct) cts) =
  SCons (Value (toTan ty pr ct)) (toTanE tys prs cts)

-- | Convert a cotangent (a value of @D2 t@) at the given primal point into a
-- tangent (a value of @Tan t@).
--
-- The primal value dictates the structure of the result. Wherever the
-- cotangent uses its compact zero form, the result is the zero tangent of the
-- corresponding primal sub-value, built with 'zeroTan'. The compact zero form
-- is a 'Nothing' for pairs, sums and 'Maybe's, or an array with zero elements.
--
-- Calls 'error' in these cases:
--
-- * the primal and cotangent disagree on an 'Either' alternative;
-- * the primal and a non-empty cotangent array differ in shape;
-- * the type is an accumulator type, which is not allowed in input programs.
toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
toTan typ primal der = case typ of
  STNil -> der
  STPair t1 t2 -> case der of
                    -- zero cotangent: zero tangent in both components
                    Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
                    Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
  STEither t1 t2 -> case der of
                      -- zero cotangent: zero tangent in whichever alternative the primal is
                      Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
                      Just d -> case (primal, d) of
                        (Left p, Left d') -> Left (toTan t1 p d')
                        (Right p, Right d') -> Right (toTan t2 p d')
                        _ -> error "Primal and cotangent disagree on Either alternative"
  STMaybe t -> case (primal, der) of
    (Nothing, _) -> Nothing
    -- A zero ('Nothing') cotangent at a 'Just' primal must still yield a
    -- 'Just' tangent, so the tangent's structure follows the primal, as in
    -- the other cases. The previous 'liftA2' returned 'Nothing' here.
    (Just p, Nothing) -> Just (zeroTan t p)
    (Just p, Just d) -> Just (toTan t p d)
  STArr _ t
    -- an empty cotangent array is the zero cotangent
    | shapeSize (arrayShape der) == 0 ->
        arrayMap (zeroTan t) primal
    | arrayShape primal == arrayShape der ->
        arrayGenerateLin (arrayShape primal) $ \i ->
          toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
    | otherwise ->
        error "Primal and cotangent disagree on array shape"
  -- Matching on each scalar type separately lets the D2 / Tan type
  -- families reduce, so 'der' can be returned as-is in every branch.
  STScal sty -> case sty of
    STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
  STAccum{} -> error "Accumulators not allowed in input program"