blob: 72122322dffd2ceadbe986717ba841dc19d85f7b (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
|
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
-- | TODO: this module is a grab-bag of utility functions shared between
-- "CHAD" and "CHAD.Top".
module CHAD.Accum where
import AST
import CHAD.Types
import Data
import AST.Env
-- | Compute the 'ZeroInfo' for the 'D2' type of @t@ from an expression
-- producing the 'D1' (primal) value of type @t@.
--
-- Pairs recurse into both components and arrays map the element case over
-- every element. 'STNil', 'STEither', 'STLEither', 'STMaybe' and scalar types
-- all yield 'ENil'. Hitting an accumulator type is an error: accumulators may
-- not occur in source programs.
d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
d2zeroInfo ty primal = case ty of
  STNil -> ENil ext
  STPair ta tb ->
    eunPair primal $ \_ lhs rhs ->
      EPair ext (d2zeroInfo ta lhs) (d2zeroInfo tb rhs)
  STEither{} -> ENil ext
  STLEither{} -> ENil ext
  STMaybe{} -> ENil ext
  -- Inside the map, the current element is the innermost variable ('IZ').
  STArr _ elt -> emap (d2zeroInfo elt (EVar ext (d1 elt) IZ)) primal
  -- The lemma shows the scalar 'ZeroInfo' is the unit type, so 'ENil' fits.
  STScal sty | Refl <- lemZeroInfoScal sty -> ENil ext
  STAccum{} -> error "accumulators not allowed in source program"
-- | Compute the 'DeepZeroInfo' for the 'D2' type of @t@ from an expression
-- producing the 'D1' (primal) value of type @t@.
--
-- This follows the primal into sum-like types:
--
-- * an 'STEither' primal becomes 'ELInl' or 'ELInr' according to its tag;
-- * an 'STLEither' primal maps its nil/left/right cases to
--   'ELNil'/'ELInl'/'ELInr';
-- * an 'STMaybe' primal maps to 'ENothing'/'EJust'.
--
-- Pairs recurse into both components, arrays map the element case over every
-- element, and 'STNil' and scalars yield 'ENil'. Accumulator types are
-- rejected with an error.
d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t))
d2deepZeroInfo ty primal = case ty of
  STNil -> ENil ext
  STPair ta tb ->
    eunPair primal $ \_ lhs rhs ->
      EPair ext (d2deepZeroInfo ta lhs) (d2deepZeroInfo tb rhs)
  -- In every branch below, the payload of the scrutinised value is
  -- referenced as the innermost variable ('IZ').
  STEither ta tb ->
    ECase ext primal
      (ELInl ext (tDeepZeroInfo (d2M tb)) (d2deepZeroInfo ta (EVar ext (d1 ta) IZ)))
      (ELInr ext (tDeepZeroInfo (d2M ta)) (d2deepZeroInfo tb (EVar ext (d1 tb) IZ)))
  STLEither ta tb ->
    elcase primal
      (ELNil ext (tDeepZeroInfo (d2M ta)) (tDeepZeroInfo (d2M tb)))
      (ELInl ext (tDeepZeroInfo (d2M tb)) (d2deepZeroInfo ta (EVar ext (d1 ta) IZ)))
      (ELInr ext (tDeepZeroInfo (d2M ta)) (d2deepZeroInfo tb (EVar ext (d1 tb) IZ)))
  STMaybe ta ->
    emaybe primal
      (ENothing ext (tDeepZeroInfo (d2M ta)))
      (EJust ext (d2deepZeroInfo ta (EVar ext (d1 ta) IZ)))
  STArr _ elt -> emap (d2deepZeroInfo elt (EVar ext (d1 elt) IZ)) primal
  -- The lemma shows the scalar 'DeepZeroInfo' is the unit type.
  STScal sty | Refl <- lemDeepZeroInfoScal sty -> ENil ext
  STAccum{} -> error "accumulators not allowed in source program"
-- | Wrap @e@ in one 'EWith' per entry of @envPro@. Per its type, @e@ is typed
-- in an environment extended with 'D2AcE' bindings for those entries.
--
-- Each 'EWith' is initialised with 'EDeepZero'. Its 'DeepZeroInfo' comes from
-- 'd2deepZeroInfo' applied to that entry's 'D1' value, which is located in
-- @env@ through the weakening @w@. The first entry of @envPro@ is wrapped
-- innermost; the remaining entries wrap around it in turn.
makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
makeAccumulators _ SNil body = body
makeAccumulators w (ty `SCons` rest) body =
  let -- This entry's primal value: the index 'IZ' transported through @w@
      -- and past the bindings introduced for the remaining entries.
      primal = EVar ext (d1 ty) (wSinks (d2ace rest) .> w @> IZ)
      initial = EDeepZero ext (d2M ty) (d2deepZeroInfo ty primal)
  in makeAccumulators (WPop w) rest (EWith ext (d2M ty) initial body)
-- | Turn an 'InvTup' into a pair of the core value and a 'Tup' of the list's
-- element values.
--
-- In an 'InvTup' the core is successively paired with each element of
-- @list@, which the recursive call below relies on.
--
-- For a non-empty list:
--
-- 1. The tail is un-inverted with the head's type folded into the core type.
-- 2. That result is bound with 'ELet'.
-- 3. Its components are reshuffled so the core comes out first, and the head
--    element becomes the second component of the resulting 'Tup'.
uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
uninvertTup SNil _ e = EPair ext e (ENil ext)
uninvertTup (t `SCons` list) tcore e =
  ELet ext (uninvertTup list (STPair tcore t) e) $
    let -- Type of the value bound by the 'ELet': ((core, t), Tup list).
        boundTy = STPair (STPair tcore t) (tTup list)
        bound = EVar ext boundTy IZ
        coreAndHead = EFst ext bound
    in EPair ext
         (EFst ext coreAndHead)
         (EPair ext (ESnd ext bound) (ESnd ext coreAndHead))
-- | Transport a 'Subenv' to the 'D1E' images of its environments. It is
-- rebuilt constructor by constructor ('SETop', 'SEYesR', 'SENo'), so the
-- selection keeps exactly the same shape.
subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
subenvD1E sub = case sub of
  SETop -> SETop
  SEYesR rest -> SEYesR (subenvD1E rest)
  SENo rest -> SENo (subenvD1E rest)
|