summaryrefslogtreecommitdiff
path: root/src/AST/UnMonoid.hs
blob: 1675dabd28fe5eeaf972f9e4b13fe4fafe0245f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
module AST.UnMonoid where

import AST
import CHAD.Types
import Data


-- | Remove the monoid operations ('EZero', 'EPlus', 'EOneHot') from an
-- expression by expanding them into ordinary AST constructs. Every other
-- constructor is rebuilt with its subexpressions processed recursively and
-- its annotation reset to 'ext'.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
  EZero t -> zero t
  -- The operands may themselves contain monoid operations. 'plus' embeds its
  -- operands unchanged in the expression it builds, so they must be
  -- processed first or nested EZero/EPlus/EOneHot would survive this pass.
  EPlus t a b -> plus t (unMonoid a) (unMonoid b)
  -- TODO(review): the one-hot expansion is still an unfilled typed hole.
  -- Its operands are already processed, so the eventual implementation
  -- receives monoid-free subterms.
  EOneHot t i a b -> _ t i (unMonoid a) (unMonoid b)

  EVar _ t i -> EVar ext t i
  ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
  EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
  EFst _ e -> EFst ext (unMonoid e)
  ESnd _ e -> ESnd ext (unMonoid e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext t (unMonoid e)
  EInr _ t e -> EInr ext t (unMonoid e)
  ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
  ENothing _ t -> ENothing ext t
  EJust _ e -> EJust ext (unMonoid e)
  EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
  EFold1Inner _ a b c -> EFold1Inner ext (unMonoid a) (unMonoid b) (unMonoid c)
  ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
  EUnit _ e -> EUnit ext (unMonoid e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unMonoid e)
  EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
  EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
  EShape _ e -> EShape ext (unMonoid e)
  EOp _ op e -> EOp ext op (unMonoid e)
  ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
  EWith a b -> EWith (unMonoid a) (unMonoid b)
  EAccum n a b e -> EAccum n (unMonoid a) (unMonoid b) (unMonoid e)
  EError t s -> EError t s

-- | The zero cotangent of a type, as an expression of type @D2 t@.
--
-- * Unit and the integer/boolean scalars have a trivial cotangent ('ENil').
-- * Pairs, sums and maybes have a sparse cotangent whose zero is 'ENothing'.
-- * An array's zero is a build whose shape has every dimension set to 0
--   (so it is empty for rank >= 1), with the element zero as its body.
-- * Floating-point scalars get a literal @0.0@.
-- * Accumulator types are rejected with a runtime 'error'.
zero :: STy t -> Ex env (D2 t)
zero = \case
  STNil -> ENil ext
  STPair l r -> ENothing ext (STPair (d2 l) (d2 r))
  STEither l r -> ENothing ext (STEither (d2 l) (d2 r))
  STMaybe inner -> ENothing ext (d2 inner)
  STArr n elt ->
    EBuild ext n (eTup $ sreplicate n $ EConst ext STI64 0) (zero elt)
  STScal sty -> case sty of
    STF32 -> EConst ext STF32 0.0
    STF64 -> EConst ext STF64 0.0
    STI32 -> ENil ext
    STI64 -> ENil ext
    STBool -> ENil ext
  STAccum{} -> error "Accumulators not allowed in input program"

-- | Build an expression adding two cotangent values of type @D2 t@.
--
-- Pair, sum and maybe cotangents are Maybe-wrapped (see 'zero'), so those
-- cases go through 'plusSparse'. If either side is 'Nothing' the other side
-- is the result. Otherwise the adder given here combines the two unwrapped
-- values. Inside every adder below, the LEFT operand's payload is
-- @EVar … (IS IZ)@ and the RIGHT operand's payload is @EVar … IZ@.
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
-- Unit cotangents carry no information; both operands are discarded.
plus STNil _ _ = ENil ext
-- Component-wise: add the first components and the second components.
plus (STPair t1 t2) a b =
  let t = STPair (d2 t1) (d2 t2)
  in plusSparse t a b $
       EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
                          (EFst ext (EVar ext t IZ)))
                 (plus t2 (ESnd ext (EVar ext t (IS IZ)))
                          (ESnd ext (EVar ext t IZ)))
-- Both operands must use the same constructor. A Left/Right mismatch yields
-- a runtime 'EError' ("plus l+r" / "plus r+l").
plus (STEither t1 t2) a b =
  let t = STEither (d2 t1) (d2 t2)
  in plusSparse t a b $
       -- Scrutinise the left operand. Under each branch's binder, the right
       -- operand moves from IZ to (IS IZ).
       ECase ext (EVar ext t (IS IZ))
         -- Left operand is Inl: the right operand must also be Inl. Inside,
         -- (IS IZ) is the left payload and IZ the right payload.
         (ECase ext (EVar ext t (IS IZ))
           (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
           (EError t "plus l+r"))
         -- Left operand is Inr: the right operand must also be Inr.
         (ECase ext (EVar ext t (IS IZ))
           (EError t "plus r+l")
           (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
-- A Maybe cotangent is itself the sparse wrapper; add the payloads when both
-- are present.
plus (STMaybe t) a b =
  plusSparse (d2 t) a b $
    plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
-- NOTE(review): unfinished. Both operands are let-bound (a at IS IZ, b at IZ),
-- but 'ECase' is then given no arguments, so this equation does not
-- type-check yet. TODO: implement element-wise array addition. It presumably
-- also has to cope with the zero-extent array produced by 'zero'; confirm.
plus (STArr n t) a b =
  ELet ext a $
  ELet ext (weakenExpr WSink b) $
    ECase 
-- Integer and boolean cotangents are trivial ('ENil'). Floats are summed with
-- the primitive 'OAdd' applied to the pair of operands.
plus (STScal t) a b = case t of
  STI32 -> ENil ext
  STI64 -> ENil ext
  STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
  STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
  STBool -> ENil ext
plus STAccum{} _ _ = error "Accumulators not allowed in input program"

-- | Add two sparse (Maybe-wrapped) values, using @combine@ when both are
-- present.
--
-- The right operand is let-bound first. Then:
--
-- * left is 'Nothing' → the result is the right operand;
-- * left is @Just x@ and right is 'Nothing' → @Just x@;
-- * left is @Just x@ and right is @Just y@ → @Just combine@, where inside
--   @combine@ @x@ is @IS IZ@ and @y@ is @IZ@.
plusSparse :: STy a
           -> Ex env (TMaybe a) -> Ex env (TMaybe a)
           -> Ex (a : a : env) a
           -> Ex env (TMaybe a)
plusSparse ty lhs rhs combine =
  ELet ext rhs (EMaybe ext rhsOnly lhsPresent (weakenExpr WSink lhs))
  where
    -- Scope: (rhs : env). The left operand was Nothing, so return rhs.
    rhsOnly = EVar ext (STMaybe ty) IZ
    -- Scope: (x : rhs : env), where x is the left payload. Inspect rhs,
    -- which now sits at (IS IZ).
    lhsPresent =
      EJust ext (EMaybe ext lhsPayload bothPresent (EVar ext (STMaybe ty) (IS IZ)))
    -- The right operand was Nothing: keep the left payload x.
    lhsPayload = EVar ext ty IZ
    -- Scope: (y : x : rhs : env). Skip over rhs so that combine sees (y : x : env).
    bothPresent = weakenExpr (WCopy (WCopy WSink)) combine