summaryrefslogtreecommitdiff
path: root/src/AST/Types.hs
blob: acf705339868f7bacbdddeb361d0084559e0d59f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
module AST.Types where

import Data.Int (Int32, Int64)
import Data.Kind (Type)
import Data.Type.Equality

import Data


-- | Value-level description of the types handled by this module.
--
-- With @DataKinds@, these constructors are also promoted to the type level,
-- where they index the singleton 'STy'.
--
-- NOTE: the derived 'Ord' (and 'Show') instances depend on the constructor
-- order below, so reordering constructors changes their behaviour.
data Ty
  = TNil
  | TPair Ty Ty          -- ^ pair of two types
  | TEither Ty Ty        -- ^ choice between two types
  | TMaybe Ty            -- ^ optional value of the given type
  | TArr Nat Ty  -- ^ array: rank ('Nat'), then element type
  | TScal ScalTy         -- ^ scalar of the given 'ScalTy'
  | TAccum Ty            -- ^ accumulator over the given type
  deriving (Show, Eq, Ord)

-- | Scalar types: 32- and 64-bit integers, single- and double-precision
-- floats, and booleans. 'ScalRep' gives the Haskell type used for each.
-- The constructor order determines the derived 'Ord' instance.
data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
  deriving (Show, Eq, Ord)

-- | Singleton for 'Ty'. Each constructor mirrors one 'Ty' constructor and
-- fixes the type index @t@, so pattern matching on an @STy t@ value tells
-- GHC which 'Ty' @t@ is.
--
-- The rank of 'STArr' is carried by an 'SNat' (presumably the singleton for
-- 'Nat' from the "Data" module; TODO confirm).
type STy :: Ty -> Type
data STy t where
  STNil :: STy TNil
  STPair :: STy a -> STy b -> STy (TPair a b)
  STEither :: STy a -> STy b -> STy (TEither a b)
  STMaybe :: STy a -> STy (TMaybe a)
  STArr :: SNat n -> STy t -> STy (TArr n t)
  STScal :: SScalTy t -> STy (TScal t)
  STAccum :: STy t -> STy (TAccum t)
-- Standalone deriving is needed because the constructors have specialised
-- result types, so an ordinary @deriving@ clause would be rejected.
deriving instance Show (STy t)

-- | Decide whether two 'STy' singletons describe the same 'Ty'. On success,
-- return a proof ('Refl') that their type indices are equal.
--
-- Compound types are compared component-wise in the 'Maybe' monad: each
-- @Refl <- ...@ bind either refines the component indices or
-- short-circuits to 'Nothing'.
--
-- Each constructor has its own mismatch clause (there is no final
-- wildcard). With @-Wincomplete-patterns@, adding a constructor to 'STy'
-- therefore produces a warning here.
instance TestEquality STy where
  testEquality STNil STNil = Just Refl
  testEquality STNil _ = Nothing

  testEquality (STPair l r) (STPair l' r') = do
    Refl <- testEquality l l'
    Refl <- testEquality r r'
    pure Refl
  testEquality STPair{} _ = Nothing

  testEquality (STEither l r) (STEither l' r') = do
    Refl <- testEquality l l'
    Refl <- testEquality r r'
    pure Refl
  testEquality STEither{} _ = Nothing

  testEquality (STMaybe inner) (STMaybe inner') = do
    Refl <- testEquality inner inner'
    pure Refl
  testEquality STMaybe{} _ = Nothing

  testEquality (STArr rank elt) (STArr rank' elt') = do
    Refl <- testEquality rank rank'
    Refl <- testEquality elt elt'
    pure Refl
  testEquality STArr{} _ = Nothing

  testEquality (STScal s) (STScal s') = do
    Refl <- testEquality s s'
    pure Refl
  testEquality STScal{} _ = Nothing

  testEquality (STAccum inner) (STAccum inner') = do
    Refl <- testEquality inner inner'
    pure Refl
  testEquality STAccum{} _ = Nothing

-- | Singleton for 'ScalTy'. Each constructor fixes the promoted scalar type
-- index @t@, so matching on an @SScalTy t@ value tells GHC which scalar
-- type @t@ is. The kind of @t@ is inferred from the constructors.
data SScalTy t where
  STI32 :: SScalTy TI32
  STI64 :: SScalTy TI64
  STF32 :: SScalTy TF32
  STF64 :: SScalTy TF64
  STBool :: SScalTy TBool
-- Standalone deriving is needed because the constructors have specialised
-- result types, so an ordinary @deriving@ clause would be rejected.
deriving instance Show (SScalTy t)

-- | Decide whether two scalar-type singletons are the same. On success,
-- return a proof ('Refl') that their type indices are equal.
--
-- This follows the style of the 'STy' instance: every constructor has its
-- own mismatch clause instead of one trailing @_ _ = Nothing@ wildcard. With
-- @-Wincomplete-patterns@, adding a constructor to 'SScalTy' is then
-- reported here, instead of silently comparing unequal to everything.
instance TestEquality SScalTy where
  testEquality STI32 STI32 = Just Refl
  testEquality STI32 _ = Nothing
  testEquality STI64 STI64 = Just Refl
  testEquality STI64 _ = Nothing
  testEquality STF32 STF32 = Just Refl
  testEquality STF32 _ = Nothing
  testEquality STF64 STF64 = Just Refl
  testEquality STF64 _ = Nothing
  testEquality STBool STBool = Just Refl
  testEquality STBool _ = Nothing

-- | Produce evidence that the Haskell representation of a scalar type has a
-- 'Show' instance.
--
-- Matching on the singleton fixes @t@. 'ScalRep' then reduces to a concrete
-- type ('Int32', 'Int64', 'Float', 'Double' or 'Bool'), and that type's
-- 'Show' instance is packaged in a 'Dict'.
scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
scalRepIsShow sty = case sty of
  STI32 -> Dict
  STI64 -> Dict
  STF32 -> Dict
  STF64 -> Dict
  STBool -> Dict

-- | Shorthand for the 64-bit integer scalar type. The name suggests it is
-- the type of indices (TODO confirm against its uses).
type TIx = TScal TI64

-- | The singleton for 'TIx'.
tIx :: STy TIx
tIx = STScal STI64

-- | The Haskell type that represents values of each scalar type.
-- All equations are mutually apart, so their order does not affect
-- reduction.
type family ScalRep t where
  ScalRep TI32 = Int32
  ScalRep TI64 = Int64
  ScalRep TF32 = Float
  ScalRep TF64 = Double
  ScalRep TBool = Bool

-- | Type-level predicate: @True@ for the integer and floating-point scalar
-- types, @False@ for 'TBool'.
type family ScalIsNumeric t where
  ScalIsNumeric TI32 = True
  ScalIsNumeric TI64 = True
  ScalIsNumeric TF32 = True
  ScalIsNumeric TF64 = True
  ScalIsNumeric TBool = False

-- | Type-level predicate: @True@ only for the floating-point scalar types
-- ('TF32', 'TF64').
type family ScalIsFloating t where
  ScalIsFloating TI32 = False
  ScalIsFloating TI64 = False
  ScalIsFloating TF32 = True
  ScalIsFloating TF64 = True
  ScalIsFloating TBool = False

-- | Type-level predicate: @True@ only for the integer scalar types
-- ('TI32', 'TI64').
type family ScalIsIntegral t where
  ScalIsIntegral TI32 = True
  ScalIsIntegral TI64 = True
  ScalIsIntegral TF32 = False
  ScalIsIntegral TF64 = False
  ScalIsIntegral TBool = False

-- | Whether a type contains an array anywhere inside it. Accumulators count
-- as well: the result is 'True' if an 'STArr' /or/ an 'STAccum' occurs
-- anywhere in the type, and 'False' otherwise. Pairs, sums and optionals
-- are searched recursively.
hasArrays :: STy t' -> Bool
hasArrays ty = case ty of
  STArr{} -> True
  STAccum{} -> True
  STPair l r -> hasArrays l || hasArrays r
  STEither l r -> hasArrays l || hasArrays r
  STMaybe inner -> hasArrays inner
  STNil -> False
  STScal{} -> False