summaryrefslogtreecommitdiff
path: root/src/AST/Accum.hs
blob: 619c2b16fa02b26a7d185873216c77b32978945f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Accum where

import AST.Types
import Data


-- | Kind-level description of a path into a value, naming the part of it
-- that an accumulation targets. It is used promoted, as the index of
-- 'SAcPrj'.
data AcPrj where
  -- | Stop here: the target is the whole value.
  APHere :: AcPrj
  -- | Continue into the first component of a pair.
  APFst :: AcPrj -> AcPrj
  -- | Continue into the second component of a pair.
  APSnd :: AcPrj -> AcPrj
  -- | Continue into the left branch of a 'TLEither'.
  APLeft :: AcPrj -> AcPrj
  -- | Continue into the right branch of a 'TLEither'.
  APRight :: AcPrj -> AcPrj
  -- | Continue into the payload of a 'TMaybe'.
  APJust :: AcPrj -> AcPrj
  -- | Continue into a single array element.
  APArrIdx :: AcPrj -> AcPrj
  -- | Target an array slice. This is not yet supported by 'SAcPrj' or
  -- 'AcIdx'.
  APArrSlice :: Nat -> AcPrj

-- | Witness that projection @p@ selects a sub-value of type @b@ from a value
-- of type @a@.
--
-- Each non-'SAPHere' constructor peels one layer off @a@. 'SAPHere' ends the
-- path and identifies the remaining type with @b@.
data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
  SAPHere   :: SAcPrj APHere a a
  SAPFst    :: SAcPrj p l b -> SAcPrj (APFst p)    (TPair l r)    b
  SAPSnd    :: SAcPrj p r b -> SAcPrj (APSnd p)    (TPair l r)    b
  SAPLeft   :: SAcPrj p l b -> SAcPrj (APLeft p)   (TLEither l r) b
  SAPRight  :: SAcPrj p r b -> SAcPrj (APRight p)  (TLEither l r) b
  SAPJust   :: SAcPrj p e b -> SAcPrj (APJust p)   (TMaybe e)     b
  SAPArrIdx :: SAcPrj p e b -> SAcPrj (APArrIdx p) (TArr n e)     b
  -- TODO: array slicing is not implemented yet:
  -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)

-- | Type-level flag selecting which flavour of index 'AcIdx' computes.
-- 'AID' carries only the path data. 'AIS' additionally carries 'ZeroInfo'
-- for pair siblings and for traversed arrays.
type data AIDense
  = AID
  | AIS

-- | Singleton for 'AIDense', so term-level code can branch on the flag.
data SAIDense d where
  SAID :: SAIDense AID
  SAIS :: SAIDense AIS
deriving instance Show (SAIDense d)

-- | The index type needed to locate the target of projection @p@ inside a
-- value of type @t@, in flavour @d@.
--
-- * 'AID': only the path data. Pair, 'TLEither' and 'TMaybe' steps add
--   nothing. Each array step adds the element index
--   @Tup (Replicate n TIx)@, presumably one 'TIx' per dimension.
-- * 'AIS': as 'AID', but each pair step also carries the 'ZeroInfo' of the
--   sibling component, and each array step also carries the 'ZeroInfo' of
--   the whole array.
--
-- 'APArrSlice' has no equations yet; drafts are commented out at the end.
-- 'AcIdx' therefore does not reduce on a slice projection.
type family AcIdx d p t where
  -- Target reached: no further index data in either flavour.
  AcIdx d APHere t = TNil
  -- Dense pairs: just follow the selected component.
  AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a
  AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b
  -- Sparse pairs: also keep the zero info of the component not taken.
  AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b)
  AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b)
  -- Sum and Maybe steps contribute nothing, in both flavours.
  AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a
  AcIdx d (APRight p) (TLEither a b) = AcIdx d p b
  AcIdx d (APJust p) (TMaybe a) = AcIdx d p a
  AcIdx AID (APArrIdx p) (TArr n a) =
    -- (index, recursive info)
    TPair (Tup (Replicate n TIx)) (AcIdx AID p a)
  AcIdx AIS (APArrIdx p) (TArr n a) =
    -- ((index, shape info), recursive info)
    TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
          (AcIdx AIS p a)
  -- AcIdx AID (APArrSlice m) (TArr n a) =
  --   -- index
  --   Tup (Replicate m TIx)
  -- AcIdx AIS (APArrSlice m) (TArr n a) =
  --   -- (index, array shape)
  --   TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))

-- | 'AcIdx' in the dense ('AID') flavour.
type AcIdxD prj ty = AcIdx AID prj ty
-- | 'AcIdx' in the sparse ('AIS') flavour.
type AcIdxS prj ty = AcIdx AIS prj ty

-- | Walk a projection through a monoid type singleton. The result is the
-- singleton of the projected-to type @b@.
acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy prj ty = case prj of
  SAPHere -> ty
  -- Each step below refines @a@, so only one 'SMTy' constructor can match.
  SAPFst rest -> case ty of SMTPair l _ -> acPrjTy rest l
  SAPSnd rest -> case ty of SMTPair _ r -> acPrjTy rest r
  SAPLeft rest -> case ty of SMTLEither l _ -> acPrjTy rest l
  SAPRight rest -> case ty of SMTLEither _ r -> acPrjTy rest r
  SAPJust rest -> case ty of SMTMaybe e -> acPrjTy rest e
  SAPArrIdx rest -> case ty of SMTArr _ e -> acPrjTy rest e

-- | Structural information about a monoid type @t@ that is stored in sparse
-- accumulator indices (see the 'AIS' equations of 'AcIdx'). Presumably it
-- lets zeros be rebuilt for the untouched parts; this should be confirmed.
-- Pairs keep both halves and arrays keep a same-rank array of element info.
-- Scalars, 'TLEither' and 'TMaybe' need nothing.
type family ZeroInfo t where
  ZeroInfo TNil           = TNil
  ZeroInfo (TScal s)      = TNil
  ZeroInfo (TPair l r)    = TPair (ZeroInfo l) (ZeroInfo r)
  ZeroInfo (TArr n e)     = TArr n (ZeroInfo e)
  ZeroInfo (TLEither l r) = TNil
  ZeroInfo (TMaybe e)     = TNil

-- | Term-level mirror of 'ZeroInfo'. Given a monoid type singleton, it
-- builds the 'STy' singleton of its zero info.
tZeroInfo :: SMTy t -> STy (ZeroInfo t)
tZeroInfo ty = case ty of
  SMTPair l r -> STPair (tZeroInfo l) (tZeroInfo r)
  SMTArr rank e -> STArr rank (tZeroInfo e)
  -- Everything else carries no zero info.
  SMTNil -> STNil
  SMTScal{} -> STNil
  SMTLEither{} -> STNil
  SMTMaybe{} -> STNil

-- | Information needed to create a zero-valued /deep/ accumulator for a
-- monoid type. It mirrors the type's structure: pairs, 'TLEither',
-- 'TMaybe' and arrays keep their shape, and scalars need nothing.
-- It should be constructible from a D1.
type family DeepZeroInfo t where
  DeepZeroInfo TNil           = TNil
  DeepZeroInfo (TScal s)      = TNil
  DeepZeroInfo (TPair l r)    = TPair (DeepZeroInfo l) (DeepZeroInfo r)
  DeepZeroInfo (TLEither l r) = TLEither (DeepZeroInfo l) (DeepZeroInfo r)
  DeepZeroInfo (TMaybe e)     = TMaybe (DeepZeroInfo e)
  DeepZeroInfo (TArr n e)     = TArr n (DeepZeroInfo e)

-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
-- type family AccumInfo t where
--   AccumInfo TNil = TNil
--   AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b)
--   AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
--   AccumInfo (TMaybe a) = TMaybe (AccumInfo a)
--   AccumInfo (TArr n t) = TArr n (AccumInfo t)
--   AccumInfo (TScal t) = TNil

-- type family PrimalInfo t where
--   PrimalInfo TNil = TNil
--   PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b)
--   PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
--   PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a)
--   PrimalInfo (TArr n t) = TArr n (PrimalInfo t)
--   PrimalInfo (TScal t) = TNil

-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t)
-- tPrimalInfo SMTNil = STNil
-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b)
-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b)
-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a)
-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t)
-- tPrimalInfo (SMTScal _) = STNil