aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST/UnUser.hs
blob: 73b216c844425931887b337e6524890e2fa0d671 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module CHAD.AST.UnUser where

import CHAD.AST


-- | Erase user-defined types from a type: every 'TUser' is replaced by its
-- representation 'UserRep', which is itself erased recursively. All other
-- type constructors are traversed structurally, and scalar types are kept
-- unchanged.
--
-- The equations have pairwise distinct head constructors, so their order
-- does not affect reduction.
type family UnUser t where
  UnUser (TUser u) = UnUser (UserRep u)
  UnUser TNil = TNil
  UnUser (TScal s) = TScal s
  UnUser (TPair l r) = TPair (UnUser l) (UnUser r)
  UnUser (TEither l r) = TEither (UnUser l) (UnUser r)
  UnUser (TLEither l r) = TLEither (UnUser l) (UnUser r)
  UnUser (TMaybe e) = TMaybe (UnUser e)
  UnUser (TArr n e) = TArr n (UnUser e)
  UnUser (TAccum e) = TAccum (UnUser e)

-- | Map 'UnUser' over every entry of a type-level environment (a list of
-- types), preserving its length and order.
type family UnUserE env where
  UnUserE (x : xs) = UnUser x : UnUserE xs
  UnUserE '[] = '[]

-- | Term-level counterpart of 'UnUser': rewrite a type witness so that every
-- user type is replaced by its (recursively erased) representation.
unUserTy :: STy t -> STy (UnUser t)
unUserTy (STUser u) = unUserTy (userRepTy u)
unUserTy STNil = STNil
unUserTy (STScal s) = STScal s
unUserTy (STPair l r) = STPair (unUserTy l) (unUserTy r)
unUserTy (STEither l r) = STEither (unUserTy l) (unUserTy r)
unUserTy (STLEither l r) = STLEither (unUserTy l) (unUserTy r)
unUserTy (STMaybe e) = STMaybe (unUserTy e)
unUserTy (STArr n e) = STArr n (unUserTy e)
-- The accumulator's element witness is an 'SMTy', so it is rewritten by
-- 'unUserMTy' rather than recursively by 'unUserTy'.
unUserTy (STAccum m) = STAccum (unUserMTy m)

-- | Variant of 'unUserTy' for monoid-type witnesses ('SMTy'): replace every
-- user type by its (recursively erased) representation.
unUserMTy :: SMTy t -> SMTy (UnUser t)
-- NOTE(review): 'unUserTy' feeds the result of 'userRepTy' to a function
-- expecting an 'STy'; confirm 'userRepTy' yields an 'SMTy' for an 'SMTUser'
-- payload (or whether a monoid-specific variant is needed here).
unUserMTy (SMTUser u) = unUserMTy (userRepTy u)
unUserMTy SMTNil = SMTNil
unUserMTy (SMTScal s) = SMTScal s
unUserMTy (SMTPair l r) = SMTPair (unUserMTy l) (unUserMTy r)
unUserMTy (SMTLEither l r) = SMTLEither (unUserMTy l) (unUserMTy r)
unUserMTy (SMTMaybe e) = SMTMaybe (unUserMTy e)
unUserMTy (SMTArr n e) = SMTArr n (unUserMTy e)

-- | Erase user-defined types from an expression. The result's type is
-- 'UnUser' of the input's type, and its environment is 'UnUserE' of the
-- input's environment. Type witnesses carried by nodes are rewritten with
-- 'unUserTy' / 'unUserMTy' so that they agree with the erased types.
--
-- 'EUser' and 'EUnUser' are dropped: once user types are erased, wrapping
-- into or unwrapping from a user type no longer changes the type.
--
-- Monoid operations ('EZero', 'EDeepZero', 'EPlus', 'EOneHot') are not
-- supported; encountering one calls 'error'.
unUser :: Ex env t -> Ex (UnUserE env) (UnUser t)
unUser = \case
  EUser _ _ e -> unUser e
  EUnUser _ e -> unUser e

  -- Both the variable's type witness and its index must be rewritten to
  -- refer to the erased environment.
  EVar _ t i -> EVar ext (unUserTy t) (goIdx i)
  ELet _ rhs body -> ELet ext (unUser rhs) (unUser body)
  EPair _ a b -> EPair ext (unUser a) (unUser b)
  EFst _ e -> EFst ext (unUser e)
  ESnd _ e -> ESnd ext (unUser e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext (unUserTy t) (unUser e)
  EInr _ t e -> EInr ext (unUserTy t) (unUser e)
  ECase _ e a b -> ECase ext (unUser e) (unUser a) (unUser b)
  ENothing _ t -> ENothing ext (unUserTy t)
  EJust _ e -> EJust ext (unUser e)
  EMaybe _ a b e -> EMaybe ext (unUser a) (unUser b) (unUser e)
  ELNil _ t1 t2 -> ELNil ext (unUserTy t1) (unUserTy t2)
  ELInl _ t e -> ELInl ext (unUserTy t) (unUser e)
  ELInr _ t e -> ELInr ext (unUserTy t) (unUser e)
  ELCase _ e a b c -> ELCase ext (unUser e) (unUser a) (unUser b) (unUser c)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unUser a) (unUser b)
  EMap _ a b -> EMap ext (unUser a) (unUser b)
  EFold1Inner _ cm a b c -> EFold1Inner ext cm (unUser a) (unUser b) (unUser c)
  ESum1Inner _ e -> ESum1Inner ext (unUser e)
  EUnit _ e -> EUnit ext (unUser e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unUser a) (unUser b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unUser e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unUser e)
  EReshape _ n a b -> EReshape ext n (unUser a) (unUser b)
  EZip _ a b -> EZip ext (unUser a) (unUser b)
  EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unUser a) (unUser b) (unUser c)
  EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unUser a) (unUser b) (unUser c)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unUser e)
  EIdx1 _ a b -> EIdx1 ext (unUser a) (unUser b)
  EIdx _ a b -> EIdx ext (unUser a) (unUser b)
  EShape _ e -> EShape ext (unUser e)
  -- NOTE(review): 'op' keeps its original type indices; this presumably
  -- relies on 'UnUser' reducing to the identity on operator argument/result
  -- types. Confirm that this type-checks.
  EOp _ op e -> EOp ext op (unUser e)
  -- NOTE(review): confirm that the types of the custom derivative bodies
  -- ('b', 'c') commute with 'UnUser'.
  ECustom _ t1 t2 t3 a b c e1 e2 ->
    ECustom ext (unUserTy t1) (unUserTy t2) (unUserTy t3)
      (unUser a) (unUser b) (unUser c) (unUser e1) (unUser e2)
  ERecompute _ e -> ERecompute ext (unUser e)
  -- The accumulator witness is a monoid-type witness, as for 'STAccum'.
  EWith _ t a b -> EWith ext (unUserMTy t) (unUser a) (unUser b)
  -- NOTE(review): this case looks unfinished and likely does not type-check.
  -- The accumulator witness 't' is not rewritten with 'unUserMTy', and
  -- 'eval' is given to 'accumulateSparse' without 'unUser'. The final
  -- 'EAccum' also mixes two environments: 'unUser idx'' and 'unUser val2'
  -- live in 'UnUserE' of the weakened environment, whereas
  -- 'weakenExpr w (unUser eacc)' applies 'w' (presumably a weakening out of
  -- the original 'env', as used on 'eidx') to an expression in 'UnUserE env'.
  -- The projection and 'Sparse' witnesses probably need their own erasure.
  -- TODO.
  EAccum _ t p eidx sp eval eacc ->
    accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
    acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
      EAccum ext t prj' (unUser idx') (spDense (acPrjTy prj' t)) (unUser val2) (weakenExpr w (unUser eacc))
  EError _ t s -> EError ext (unUserTy t) s

  EZero{} -> err_monoid
  EDeepZero{} -> err_monoid
  EPlus{} -> err_monoid
  EOneHot{} -> err_monoid
  where
    err_monoid = error "unUser: Monoid ops found"

    -- Re-index a variable reference into the erased environment. The
    -- position is unchanged; only the type-level indices are rewritten.
    -- NOTE(review): assumes the 'Idx' constructors 'IZ' / 'IS' are in scope
    -- via CHAD.AST. Add an import if they are not.
    goIdx :: Idx env' a -> Idx (UnUserE env') (UnUser a)
    goIdx IZ = IZ
    goIdx (IS i) = IS (goIdx i)