blob: 988a450b4fdedef4636bc14e080995f62f42c29f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Accum where
import AST.Types
import Data
-- | A projection path into a value: it names which sub-part of a larger
-- value is meant. 'SAcPrj' is the type-indexed singleton that relates such
-- a path to the types of the whole and of the selected part.
data AcPrj where
  -- | The whole value; the path ends here.
  APHere :: AcPrj
  -- | Continue into the first component of a pair.
  APFst :: AcPrj -> AcPrj
  -- | Continue into the second component of a pair.
  APSnd :: AcPrj -> AcPrj
  -- | Continue into the left alternative of an 'TLEither'.
  APLeft :: AcPrj -> AcPrj
  -- | Continue into the right alternative of an 'TLEither'.
  APRight :: AcPrj -> AcPrj
  -- | Continue into the payload of a 'TMaybe'.
  APJust :: AcPrj -> AcPrj
  -- | Continue into a single element of an array.
  APArrIdx :: AcPrj -> AcPrj
  -- | A slice of an array. NOTE(review): no 'SAcPrj' constructor exists for
  -- this yet; per the TODO there, the 'Nat' looks like the number of indexed
  -- dimensions — confirm.
  APArrSlice :: Nat -> AcPrj
-- | Singleton for 'AcPrj'. A value of type @'SAcPrj' p a b@ witnesses that
-- @b@ is a small part of @a@, indicated by the projection @p@.
data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
  -- | Select all of the value.
  SAPHere :: SAcPrj APHere a a
  -- | Select inside the first component of a pair.
  SAPFst :: SAcPrj p whole part -> SAcPrj (APFst p) (TPair whole other) part
  -- | Select inside the second component of a pair.
  SAPSnd :: SAcPrj p whole part -> SAcPrj (APSnd p) (TPair other whole) part
  -- | Select inside the left alternative of an 'TLEither'.
  SAPLeft :: SAcPrj p whole part -> SAcPrj (APLeft p) (TLEither whole other) part
  -- | Select inside the right alternative of an 'TLEither'.
  SAPRight :: SAcPrj p whole part -> SAcPrj (APRight p) (TLEither other whole) part
  -- | Select inside the payload of a 'TMaybe'.
  SAPJust :: SAcPrj p whole part -> SAcPrj (APJust p) (TMaybe whole) part
  -- | Select inside one element of an array of any rank.
  SAPArrIdx :: SAcPrj p whole part -> SAcPrj (APArrIdx p) (TArr n whole) part
  -- TODO:
  -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
-- | Mode flag for 'AcIdx'. It selects which equations of that family apply
-- (presumably dense vs. sparse, judging by the names 'AcIdxD' / 'AcIdxS').
type data AIDense
  = AID
  | AIS

-- | Term-level singleton for 'AIDense'.
data SAIDense (d :: AIDense) where
  SAID :: SAIDense AID
  SAIS :: SAIDense AIS
deriving instance Show (SAIDense d)
-- | The index type that accompanies the projection @p@ into a value of
-- type @t@, in mode @d@.
--
-- Both modes store one array index per 'APArrIdx' step. In 'AIS' mode the
-- family also carries the 'ZeroInfo' of whatever the path skips over: the
-- other half of a pair, and the shape info of an indexed array. Sum-type
-- steps ('APLeft', 'APRight', 'APJust') contribute nothing in either mode.
--
-- NOTE: this is a closed family, so equation order is significant; keep
-- the order below when editing.
type family AcIdx (d :: AIDense) (p :: AcPrj) (t :: Ty) :: Ty where
  AcIdx _ APHere _ = TNil
  AcIdx AID (APFst p) (TPair l _) = AcIdx AID p l
  AcIdx AID (APSnd p) (TPair _ r) = AcIdx AID p r
  AcIdx AIS (APFst p) (TPair l r) = TPair (AcIdx AIS p l) (ZeroInfo r)
  AcIdx AIS (APSnd p) (TPair l r) = TPair (ZeroInfo l) (AcIdx AIS p r)
  AcIdx d (APLeft p) (TLEither l _) = AcIdx d p l
  AcIdx d (APRight p) (TLEither _ r) = AcIdx d p r
  AcIdx d (APJust p) (TMaybe e) = AcIdx d p e
  AcIdx AID (APArrIdx p) (TArr n e) =
    -- (index, recursive info)
    TPair (Tup (Replicate n TIx)) (AcIdx AID p e)
  AcIdx AIS (APArrIdx p) (TArr n e) =
    -- ((index, shape info), recursive info)
    TPair
      (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n e)))
      (AcIdx AIS p e)
  -- AcIdx AID (APArrSlice m) (TArr n a) =
  --   -- index
  --   Tup (Replicate m TIx)
  -- AcIdx AIS (APArrSlice m) (TArr n a) =
  --   -- (index, array shape)
  --   TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
-- | 'AcIdx' fixed to the 'AID' mode.
type AcIdxD prj ty = AcIdx AID prj ty

-- | 'AcIdx' fixed to the 'AIS' mode.
type AcIdxS prj ty = AcIdx AIS prj ty
-- | Given the monoid-type singleton of a whole value, return the singleton
-- for the part selected by the projection.
--
-- Each projection step fixes the outer constructor of the whole type, so
-- each inner @case@ has exactly one reachable alternative.
acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy prj whole = case prj of
  SAPHere -> whole
  SAPFst rest -> case whole of SMTPair l _ -> acPrjTy rest l
  SAPSnd rest -> case whole of SMTPair _ r -> acPrjTy rest r
  SAPLeft rest -> case whole of SMTLEither l _ -> acPrjTy rest l
  SAPRight rest -> case whole of SMTLEither _ r -> acPrjTy rest r
  SAPJust rest -> case whole of SMTMaybe e -> acPrjTy rest e
  SAPArrIdx rest -> case whole of SMTArr _ e -> acPrjTy rest e
-- | Auxiliary data associated with a zero of the monoid type @t@
-- (presumably what is needed to build one — contrast 'DeepZeroInfo').
--
-- Pairs combine the info of both halves. Sum types and scalars need none.
-- Arrays keep an array of per-element info of the same rank; 'AcIdx'
-- uses that as the "shape info" of an indexed array.
type family ZeroInfo (t :: Ty) :: Ty where
  ZeroInfo TNil = TNil
  ZeroInfo (TPair l r) = TPair (ZeroInfo l) (ZeroInfo r)
  ZeroInfo (TLEither _ _) = TNil
  ZeroInfo (TMaybe _) = TNil
  ZeroInfo (TArr n e) = TArr n (ZeroInfo e)
  ZeroInfo (TScal _) = TNil
-- | Term-level counterpart of 'ZeroInfo': map a monoid-type singleton to
-- the type singleton of its zero info.
tZeroInfo :: SMTy t -> STy (ZeroInfo t)
tZeroInfo ty = case ty of
  -- Cases that carry no information.
  SMTNil -> STNil
  SMTLEither _ _ -> STNil
  SMTMaybe _ -> STNil
  SMTScal _ -> STNil
  -- Structural cases: recurse into the components.
  SMTPair l r -> STPair (tZeroInfo l) (tZeroInfo r)
  SMTArr n e -> STArr n (tZeroInfo e)
-- | Info needed to create a zero-valued deep accumulator for a monoid type.
-- Should be constructable from a D1.
--
-- Unlike 'ZeroInfo', this family keeps the structure of sum types: it
-- recurses into both alternatives of a 'TLEither' and into the payload of
-- a 'TMaybe'. Only scalars map to 'TNil'.
type family DeepZeroInfo (t :: Ty) :: Ty where
  DeepZeroInfo TNil = TNil
  DeepZeroInfo (TPair l r) = TPair (DeepZeroInfo l) (DeepZeroInfo r)
  DeepZeroInfo (TLEither l r) = TLEither (DeepZeroInfo l) (DeepZeroInfo r)
  DeepZeroInfo (TMaybe e) = TMaybe (DeepZeroInfo e)
  DeepZeroInfo (TArr n e) = TArr n (DeepZeroInfo e)
  DeepZeroInfo (TScal _) = TNil
-- | Term-level counterpart of 'DeepZeroInfo': map a monoid-type singleton
-- to the type singleton of its deep zero info.
tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
tDeepZeroInfo ty = case ty of
  -- Leaves carry no information.
  SMTNil -> STNil
  SMTScal _ -> STNil
  -- Every other constructor is preserved, with its components mapped.
  SMTPair l r -> STPair (tDeepZeroInfo l) (tDeepZeroInfo r)
  SMTLEither l r -> STLEither (tDeepZeroInfo l) (tDeepZeroInfo r)
  SMTMaybe e -> STMaybe (tDeepZeroInfo e)
  SMTArr n e -> STArr n (tDeepZeroInfo e)
-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
-- type family AccumInfo t where
-- AccumInfo TNil = TNil
-- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b)
-- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
-- AccumInfo (TMaybe a) = TMaybe (AccumInfo a)
-- AccumInfo (TArr n t) = TArr n (AccumInfo t)
-- AccumInfo (TScal t) = TNil
-- type family PrimalInfo t where
-- PrimalInfo TNil = TNil
-- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b)
-- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
-- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a)
-- PrimalInfo (TArr n t) = TArr n (PrimalInfo t)
-- PrimalInfo (TScal t) = TNil
-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t)
-- tPrimalInfo SMTNil = STNil
-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b)
-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b)
-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a)
-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t)
-- tPrimalInfo (SMTScal _) = STNil
|