summaryrefslogtreecommitdiff
path: root/src/AST/Types.hs
blob: 0b41671413d71e04079b95f621a2f13a7c77e796 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module AST.Types where

import Data.Int (Int32, Int64)
import Data.Kind (Type)
import Data.Type.Equality

import Data


-- | Types of the language described by this AST. With @DataKinds@ these
-- constructors are also promoted and used as indices of 'STy'.
--
-- NOTE: 'Ord' is derived, so constructor order determines the ordering;
-- reordering constructors changes 'compare' results.
data Ty
  = TNil  -- ^ the empty tuple type ('Tup' of an empty list)
  | TPair Ty Ty
  | TEither Ty Ty
  | TMaybe Ty
  | TArr Nat Ty  -- ^ rank, element type
  | TScal ScalTy
  | TAccum Ty  -- ^ the accumulator contains D2 of this type
  deriving (Show, Eq, Ord)

-- | Scalar element types. 'ScalRep' maps them to 'Int32', 'Int64', 'Float',
-- 'Double' and 'Bool' respectively.
data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
  deriving (Show, Eq, Ord)

-- | Term-level witness (singleton) for a promoted 'Ty': each constructor
-- mirrors the 'Ty' constructor of the same name, so pattern matching on an
-- @STy t@ reveals the shape of @t@ to the type checker.
type STy :: Ty -> Type
data STy t where
  STNil :: STy TNil
  STPair :: STy a -> STy b -> STy (TPair a b)
  STEither :: STy a -> STy b -> STy (TEither a b)
  STMaybe :: STy a -> STy (TMaybe a)
  -- | Rank witness (an 'SNat', presumably from the project's "Data" module)
  -- and element type.
  STArr :: SNat n -> STy t -> STy (TArr n t)
  STScal :: SScalTy t -> STy (TScal t)
  STAccum :: STy t -> STy (TAccum t)
deriving instance Show (STy t)

-- | Structural comparison of type witnesses. When both sides describe the
-- same type, the result carries a proof ('Refl') that their indices agree.
--
-- Every constructor keeps its own mismatch fallback rather than sharing one
-- catch-all clause, so -Wincomplete-patterns flags this instance whenever
-- 'STy' gains a constructor.
instance TestEquality STy where
  testEquality STNil STNil = Just Refl
  testEquality STNil _ = Nothing

  testEquality (STPair l r) (STPair l' r') = do
    Refl <- testEquality l l'
    Refl <- testEquality r r'
    pure Refl
  testEquality STPair{} _ = Nothing

  testEquality (STEither l r) (STEither l' r') = do
    Refl <- testEquality l l'
    Refl <- testEquality r r'
    pure Refl
  testEquality STEither{} _ = Nothing

  testEquality (STMaybe x) (STMaybe x') = do
    Refl <- testEquality x x'
    pure Refl
  testEquality STMaybe{} _ = Nothing

  testEquality (STArr rank elt) (STArr rank' elt') = do
    Refl <- testEquality rank rank'
    Refl <- testEquality elt elt'
    pure Refl
  testEquality STArr{} _ = Nothing

  testEquality (STScal s) (STScal s') = do
    Refl <- testEquality s s'
    pure Refl
  testEquality STScal{} _ = Nothing

  testEquality (STAccum x) (STAccum x') = do
    Refl <- testEquality x x'
    pure Refl
  testEquality STAccum{} _ = Nothing

-- | Term-level witness (singleton) for a promoted 'ScalTy'; matching on it
-- reveals which scalar type the index is.
--
-- The standalone kind signature matches the one given for 'STy' and pins the
-- index kind to 'ScalTy' explicitly instead of leaving it to inference.
type SScalTy :: ScalTy -> Type
data SScalTy t where
  STI32 :: SScalTy TI32
  STI64 :: SScalTy TI64
  STF32 :: SScalTy TF32
  STF64 :: SScalTy TF64
  STBool :: SScalTy TBool
deriving instance Show (SScalTy t)

-- | Two scalar witnesses are equal exactly when they are the same
-- constructor; in that case the proof that their indices agree is 'Refl'.
instance TestEquality SScalTy where
  testEquality lhs rhs = case (lhs, rhs) of
    (STI32, STI32) -> Just Refl
    (STI64, STI64) -> Just Refl
    (STF32, STF32) -> Just Refl
    (STF64, STF64) -> Just Refl
    (STBool, STBool) -> Just Refl
    _ -> Nothing

-- | Recover a 'Show' instance for the representation type ('ScalRep') of a
-- scalar type, packaged as a 'Dict'. Each branch fixes @t@ to a concrete
-- scalar type, so 'ScalRep' reduces and GHC can find the instance there.
scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
scalRepIsShow witness = case witness of
  STI32 -> Dict
  STI64 -> Dict
  STF32 -> Dict
  STF64 -> Dict
  STBool -> Dict

-- | Shorthand for the 64-bit signed integer scalar type. The name suggests
-- this is the type used for indices — confirm against its uses.
type TIx = TScal TI64

-- | Term-level witness for 'TIx'.
tIx :: STy TIx
tIx = STScal STI64

-- | Maps each scalar type to the Haskell type used to represent its values.
type family ScalRep t where
  ScalRep TI32 = Int32
  ScalRep TI64 = Int64
  ScalRep TF32 = Float
  ScalRep TF64 = Double
  ScalRep TBool = Bool

-- | Type-level predicate: 'True' for the integer and floating-point scalar
-- types, 'False' for 'TBool'.
type family ScalIsNumeric t where
  ScalIsNumeric TI32 = True
  ScalIsNumeric TI64 = True
  ScalIsNumeric TF32 = True
  ScalIsNumeric TF64 = True
  ScalIsNumeric TBool = False

-- | Type-level predicate: 'True' only for the floating-point scalar types
-- ('TF32', 'TF64').
type family ScalIsFloating t where
  ScalIsFloating TI32 = False
  ScalIsFloating TI64 = False
  ScalIsFloating TF32 = True
  ScalIsFloating TF64 = True
  ScalIsFloating TBool = False

-- | Type-level predicate: 'True' only for the integer scalar types
-- ('TI32', 'TI64').
type family ScalIsIntegral t where
  ScalIsIntegral TI32 = True
  ScalIsIntegral TI64 = True
  ScalIsIntegral TF32 = False
  ScalIsIntegral TF64 = False
  ScalIsIntegral TBool = False

-- | Whether the given type contains an array /or/ an accumulator anywhere
-- inside it. Nil and scalar types never do; pairs, sums and 'TMaybe' are
-- inspected recursively.
hasArrays :: STy t' -> Bool
hasArrays ty = case ty of
  STArr{} -> True
  STAccum{} -> True
  STNil -> False
  STScal{} -> False
  STMaybe inner -> hasArrays inner
  STPair l r -> hasArrays l || hasArrays r
  STEither l r -> hasArrays l || hasArrays r

-- | Pack a type-level list into nested 'TPair's terminated by 'TNil'. The
-- head of the list becomes the outermost second component, e.g.
-- @Tup '[a, b] = TPair (TPair TNil b) a@.
type family Tup env where
  Tup '[] = TNil
  Tup (t : ts) = TPair (Tup ts) t

-- | Build a value indexed by @'Tup' list@ from one value per list element.
-- @nil@ stands for the empty tuple and @pair@ combines two values. The shape
-- follows 'Tup': the head element is paired last, as the second component.
mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b))
      -> SList f list -> f (Tup list)
mkTup nil pair list = case list of
  SNil -> nil
  SCons hd tl -> pair (mkTup nil pair tl) hd

-- | Witness for @'Tup' env@, assembled from the witnesses of its components
-- with 'STNil' as the base and 'STPair' as the combiner.
tTup :: SList STy env -> STy (Tup env)
tTup env = mkTup STNil STPair env

-- | Counterpart of 'mkTup': split a value indexed by @'Tup' list@ back into
-- one value per list element, using @unpack@ to take a pair apart.
-- The 'SList' argument only supplies the spine; its elements are ignored.
unTup :: (forall a b. c (TPair a b) -> (c a, c b))
      -> SList f list -> c (Tup list) -> SList c list
unTup unpack spine tup = case spine of
  SNil -> SNil
  SCons _ rest ->
    -- Lazy pattern binding, matching the original: the pair is only taken
    -- apart when one of its halves is demanded.
    let (inner, lastElem) = unpack tup
    in lastElem `SCons` unTup unpack rest inner

-- | Push the types of a list onto an existing tuple type @core@, left to
-- right, each becoming the new second component. The first list element ends
-- up innermost: @InvTup core '[a, b] = TPair (TPair core a) b@. Hence
-- @InvTup TNil xs@ equals 'Tup' of the reversed list.
type family InvTup core env where
  InvTup core '[] = core
  InvTup core (t : ts) = InvTup (TPair core t) ts