blob: 73b216c844425931887b337e6524890e2fa0d671 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module CHAD.AST.UnUser where
import CHAD.AST
-- | Erase user-defined types at the type level: every 'TUser' occurring
-- anywhere inside a type is replaced, recursively, by the erasure of its
-- representation type 'UserRep'. All other type constructors are kept and
-- only their components are translated.
type family UnUser t where
  UnUser TNil = TNil
  UnUser (TPair l r) = TPair (UnUser l) (UnUser r)
  UnUser (TEither l r) = TEither (UnUser l) (UnUser r)
  UnUser (TLEither l r) = TLEither (UnUser l) (UnUser r)
  UnUser (TMaybe x) = TMaybe (UnUser x)
  UnUser (TArr n x) = TArr n (UnUser x)
  UnUser (TScal s) = TScal s
  UnUser (TAccum x) = TAccum (UnUser x)
  -- Recursing into the representation needs UndecidableInstances.
  UnUser (TUser u) = UnUser (UserRep u)
-- | Pointwise 'UnUser' over a type-level list of types (an environment).
type family UnUserE env where
  UnUserE '[] = '[]
  UnUserE (x ': xs) = UnUser x ': UnUserE xs
-- | Term-level counterpart of 'UnUser' on 'STy' witnesses: rebuild the
-- witness with every user type replaced by the witness of its
-- (recursively erased) representation.
unUserTy :: STy t -> STy (UnUser t)
unUserTy STNil = STNil
unUserTy (STPair l r) = STPair (unUserTy l) (unUserTy r)
unUserTy (STEither l r) = STEither (unUserTy l) (unUserTy r)
unUserTy (STLEither l r) = STLEither (unUserTy l) (unUserTy r)
unUserTy (STMaybe x) = STMaybe (unUserTy x)
unUserTy (STArr n x) = STArr n (unUserTy x)
unUserTy (STScal s) = STScal s
-- The accumulator's payload is an 'SMTy' witness, so delegate to unUserMTy.
unUserTy (STAccum m) = STAccum (unUserMTy m)
unUserTy (STUser u) = unUserTy (userRepTy u)
-- | Counterpart of 'unUserTy' for 'SMTy' witnesses: rebuild the witness with
-- every user type replaced by the witness of its (recursively erased)
-- representation.
unUserMTy :: SMTy t -> SMTy (UnUser t)
unUserMTy SMTNil = SMTNil
unUserMTy (SMTPair l r) = SMTPair (unUserMTy l) (unUserMTy r)
unUserMTy (SMTLEither l r) = SMTLEither (unUserMTy l) (unUserMTy r)
unUserMTy (SMTMaybe x) = SMTMaybe (unUserMTy x)
unUserMTy (SMTArr n x) = SMTArr n (unUserMTy x)
unUserMTy (SMTScal s) = SMTScal s
-- NOTE(review): unUserTy feeds the result of 'userRepTy' to a function that
-- expects an STy, while here it is fed to one that expects an SMTy. Confirm
-- that 'userRepTy' yields an SMTy for this argument, or whether an
-- SMTy-specific variant is needed.
unUserMTy (SMTUser u) = unUserMTy (userRepTy u)
-- | Erase user-defined types from an expression. The result's type and
-- environment are those of the input with 'UnUser' / 'UnUserE' applied.
--
-- * 'EUser' and 'EUnUser' nodes (conversions between a user type and its
--   representation) are dropped; only their argument is kept and translated.
-- * Type witnesses stored in constructors whose result type mentions them
--   are translated with 'unUserTy', so they agree with the translated type.
-- * Monoid operations ('EZero', 'EDeepZero', 'EPlus', 'EOneHot') are not
--   supported; encountering one calls 'error'.
unUser :: Ex env t -> Ex (UnUserE env) (UnUser t)
unUser = \case
  EUser _ _ e -> unUser e
  EUnUser _ e -> unUser e
  -- Both the variable's type witness and its index must refer to the
  -- translated environment.
  EVar _ t i -> EVar ext (unUserTy t) (goIdx i)
  ELet _ rhs body -> ELet ext (unUser rhs) (unUser body)
  EPair _ a b -> EPair ext (unUser a) (unUser b)
  EFst _ e -> EFst ext (unUser e)
  ESnd _ e -> ESnd ext (unUser e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext (unUserTy t) (unUser e)
  EInr _ t e -> EInr ext (unUserTy t) (unUser e)
  ECase _ e a b -> ECase ext (unUser e) (unUser a) (unUser b)
  -- As for EInl/EInr, the stored witness occurs in the result type
  -- (UnUser (TMaybe a) = TMaybe (UnUser a)), so it must be translated too.
  ENothing _ t -> ENothing ext (unUserTy t)
  EJust _ e -> EJust ext (unUser e)
  EMaybe _ a b e -> EMaybe ext (unUser a) (unUser b) (unUser e)
  -- Likewise for lifted eithers:
  -- UnUser (TLEither a b) = TLEither (UnUser a) (UnUser b).
  ELNil _ t1 t2 -> ELNil ext (unUserTy t1) (unUserTy t2)
  ELInl _ t e -> ELInl ext (unUserTy t) (unUser e)
  ELInr _ t e -> ELInr ext (unUserTy t) (unUser e)
  ELCase _ e a b c -> ELCase ext (unUser e) (unUser a) (unUser b) (unUser c)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unUser a) (unUser b)
  EMap _ a b -> EMap ext (unUser a) (unUser b)
  EFold1Inner _ cm a b c -> EFold1Inner ext cm (unUser a) (unUser b) (unUser c)
  ESum1Inner _ e -> ESum1Inner ext (unUser e)
  EUnit _ e -> EUnit ext (unUser e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unUser a) (unUser b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unUser e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unUser e)
  EReshape _ n a b -> EReshape ext n (unUser a) (unUser b)
  EZip _ a b -> EZip ext (unUser a) (unUser b)
  EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unUser a) (unUser b) (unUser c)
  EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unUser a) (unUser b) (unUser c)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unUser e)
  EIdx1 _ a b -> EIdx1 ext (unUser a) (unUser b)
  EIdx _ a b -> EIdx ext (unUser a) (unUser b)
  EShape _ e -> EShape ext (unUser e)
  EOp _ op e -> EOp ext op (unUser e)
  -- NOTE(review): t1/t2/t3 are passed through untranslated, unlike the
  -- witnesses above; if they index the sub-expressions' types they likely
  -- need unUserTy -- confirm against ECustom's signature.
  ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unUser a) (unUser b) (unUser c) (unUser e1) (unUser e2)
  ERecompute _ e -> ERecompute ext (unUser e)
  -- NOTE(review): 't' is passed through untranslated; if it indexes the
  -- result/accumulator type it likely needs unUserMTy (or unUserTy) --
  -- confirm against EWith's signature.
  EWith _ t a b -> EWith ext t (unUser a) (unUser b)
  -- NOTE(review): this case performs a sparse-accumulation rewrite
  -- (accumulateSparse / acPrjCompose / spDense) rather than only erasing user
  -- types. 't' is not translated, and 'eval' is handed to accumulateSparse
  -- before translation. It looks copied from a sparsity pass -- confirm the
  -- intended behaviour and that it typechecks in the translated environment.
  EAccum _ t p eidx sp eval eacc ->
    accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
    acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
      EAccum ext t prj' (unUser idx') (spDense (acPrjTy prj' t)) (unUser val2) (weakenExpr w (unUser eacc))
  EError _ t s -> EError ext t s
  EZero{} -> err_monoid
  EDeepZero{} -> err_monoid
  EPlus{} -> err_monoid
  EOneHot{} -> err_monoid
  where
    err_monoid = error "unUser: Monoid ops found"

    -- Re-type an environment index for the translated environment. The
    -- position is unchanged; only the type indices are mapped through
    -- UnUser / UnUserE. The explicit signature is required because this
    -- pattern-matches on a GADT.
    -- NOTE(review): assumes Idx / IZ / IS (EVar's index type) are in scope
    -- via CHAD.AST -- confirm.
    goIdx :: Idx env' a -> Idx (UnUserE env') (UnUser a)
    goIdx IZ = IZ
    goIdx (IS i) = IS (goIdx i)
|