summaryrefslogtreecommitdiff
path: root/src/Language/AST.hs
blob: f31f2499cae5398f9e9a616dbb0a2443e61d18e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeApplications #-}
module Language.AST where

import Data.Proxy
import Data.Type.Equality
import GHC.OverloadedLabels
import GHC.TypeLits (symbolVal, KnownSymbol)

import AST
import Data


-- | Surface-language expressions, indexed by the type @t@ of their result.
--
-- Variables are referenced by name through 'Var', and binding constructs
-- carry a 'Lambda' (a variable together with a body). 'scopeCheckExpr'
-- translates a value of this type into an 'Ex' over an explicit environment.
data SExpr t where
  -- lambda calculus
  -- Reference to a named, typed variable.
  SEVar :: Var t -> SExpr t
  -- An expression of type 'a' paired with a body that binds a variable of
  -- type 'a'.
  SELet :: SExpr a -> Lambda a (SExpr t) -> SExpr t

  -- base types
  -- Pair construction and the two projections.
  SEPair :: SExpr a -> SExpr b -> SExpr (TPair a b)
  SEFst :: SExpr (TPair a b) -> SExpr a
  SESnd :: SExpr (TPair a b) -> SExpr b
  -- The expression of type 'TNil'.
  SENil :: SExpr TNil
  -- Injections into 'TEither'; the 'STy' argument describes the alternative
  -- that is NOT being constructed.
  SEInl :: STy b -> SExpr a -> SExpr (TEither a b)
  SEInr :: STy a -> SExpr b -> SExpr (TEither a b)
  -- Case analysis on a 'TEither', with one binder-and-body per alternative.
  SECase :: SExpr (TEither a b) -> Lambda a (SExpr c) -> Lambda b (SExpr c) -> SExpr c

  -- array operations
  -- Rank-1 array from a 'TIx' expression and a body binding a 'TIx' variable
  -- (presumably the length and the per-index element; confirm against the
  -- semantics of 'EBuild1').
  SEBuild1 :: SExpr TIx -> Lambda TIx (SExpr t) -> SExpr (TArr (S Z) t)
  -- Rank-n variant of 'SEBuild1': the rank is passed as an 'SNat' and both
  -- the first expression and the bound variable are n-tuples of 'TIx'.
  SEBuild :: SNat n -> SExpr (Tup (Replicate n TIx)) -> Lambda (Tup (Replicate n TIx)) (SExpr t) -> SExpr (TArr n t)
  -- Takes a two-variable body over element type 't' and an array of rank
  -- 'S n'; the result has rank 'n'.
  SEFold1 :: Lambda t (Lambda t (SExpr t)) -> SExpr (TArr (S n) t) -> SExpr (TArr n t)
  -- Wraps an expression as a rank-zero array.
  SEUnit :: SExpr t -> SExpr (TArr Z t)

  -- expression operations
  -- Scalar constant. The 'Show' constraint on the representation type is
  -- captured here so that the constant can be shown by the derived 'Show'
  -- instance below.
  SEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> SExpr (TScal t)
  -- Element of a rank-zero array.
  SEIdx0 :: SExpr (TArr Z t) -> SExpr t
  -- Index an array of rank 'S n' with one 'TIx'; the result has rank 'n'.
  SEIdx1 :: SExpr (TArr (S n) t) -> SExpr TIx -> SExpr (TArr n t)
  -- Index a rank-n array with an n-tuple of 'TIx', yielding an element.
  SEIdx :: SNat n -> SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -> SExpr t
  -- An n-tuple of 'TIx' obtained from a rank-n array.
  SEShape :: SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx))
  -- Application of an operation 'SOp a t' to an argument of type 'a'.
  SEOp :: SOp a t -> SExpr a -> SExpr t

  -- partiality
  -- An expression of any type 'a', carrying that type and a message string.
  SEError :: STy a -> String -> SExpr a
deriving instance Show (SExpr t)

-- | A named variable of type @a@: a type representation ('STy') together
-- with the variable's name. 'scopeCheckExpr' identifies variables by
-- comparing both the name and the type.
data Var a = Var (STy a) String
  deriving (Show)

-- | A binder: a variable of type @a@ together with a body of type @b@ in
-- which that variable may be referenced. The body is stored as data, not as
-- a Haskell function; see 'mkLambda' for building one from a function.
data Lambda a b = Lambda (Var a) b
  deriving (Show)

-- | Build a 'Lambda' from a variable name and a Haskell function describing
-- the body. The variable's type is taken from the 'KnownTy' instance, and the
-- function is applied to a reference to that variable.
--
-- The name is used exactly as given; nothing here checks that it is unique.
mkLambda :: KnownTy a => String -> (SExpr a -> f t) -> Lambda a (f t)
mkLambda name body = mkLambda' binder body
  where
    binder = Var knownTy name

-- | Like 'mkLambda', but takes the bound variable explicitly instead of
-- constructing it from a name. The body is obtained by applying the given
-- function to an 'SEVar' referencing that same variable.
mkLambda' :: Var a -> (SExpr a -> f t) -> Lambda a (f t)
mkLambda' binder body = Lambda binder $ body (SEVar binder)

-- | Two-variable version of 'mkLambda': the first name becomes the outer
-- binder and the second name the inner one. Both variable types come from
-- their 'KnownTy' instances.
mkLambda2 :: (KnownTy a, KnownTy b)
          => String -> String -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t))
mkLambda2 name1 name2 body = mkLambda2' outerVar innerVar body
  where
    outerVar = Var knownTy name1
    innerVar = Var knownTy name2

-- | Two-variable version of 'mkLambda'': nests two 'Lambda's, the first
-- variable outermost, and applies the body function to references to both
-- variables (in the same order).
mkLambda2' :: Var a -> Var b -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t))
mkLambda2' outer inner body =
  Lambda outer $ Lambda inner $ body (SEVar outer) (SEVar inner)

-- | Numeric syntax for scalar-typed surface expressions.
--
-- Addition and multiplication apply the corresponding operation to a pair of
-- the operands; integer literals become 'SEConst'. Subtraction is not defined
-- here and therefore uses the 'Num' class default. 'abs' and 'signum' are
-- unsupported and call 'error' when evaluated.
instance (t ~ TScal st, KnownScalTy st, Num (ScalRep st)) => Num (SExpr t) where
  (+) x y = SEOp (OAdd knownScalTy) (SEPair x y)
  (*) x y = SEOp (OMul knownScalTy) (SEPair x y)
  negate = SEOp (ONeg knownScalTy)
  abs = error "abs undefined for SExpr"
  signum = error "signum undefined for SExpr"
  fromInteger =
    -- Matching on the dictionary brings the Show instance required by
    -- SEConst into scope.
    case scalRepIsShow scalTy of
      Dict -> SEConst scalTy . fromInteger
    where
      scalTy = knownScalTy

-- | Overloaded labels as variables: the label's text becomes the variable
-- name, and the variable's type comes from the 'KnownTy' instance.
instance (KnownTy t, KnownSymbol name) => IsLabel name (Var t) where
  fromLabel = Var knownTy labelText
    where
      labelText = symbolVal (Proxy :: Proxy name)

-- | Overloaded labels as variable references: the label is first turned into
-- a 'Var' of the expected type and then wrapped in 'SEVar'.
instance (KnownTy t, KnownSymbol name) => IsLabel name (SExpr t) where
  fromLabel = SEVar namedVar
    where
      namedVar :: Var t
      namedVar = fromLabel @name

-- | A surface-language function: the list of its argument variables and the
-- body expression of type @t@. 'scopeCheck' uses the argument list as the
-- initial environment when resolving variables in the body.
data SFun args t = SFun (SList Var args) (SExpr t)

-- | Scope-check a whole surface function: its body is converted with
-- 'scopeCheckExpr', using the function's argument variables as the
-- environment.
scopeCheck :: SFun env t -> Ex env t
scopeCheck fun = case fun of
  SFun params body -> scopeCheckExpr params body

-- | Convert a surface expression into an 'Ex' over the environment @env@.
--
-- The first argument lists the variables currently in scope; the head of the
-- list corresponds to index 'IZ'. Every 'SEVar' is resolved to an index by
-- searching that list for a variable with the same name and the same type.
-- Binding constructs ('SELet', 'SECase', 'SEBuild1', 'SEBuild', 'SEFold1')
-- push their variables onto the front of the list before converting the body.
-- All other constructors are translated structurally.
--
-- Calls 'error' when a variable cannot be resolved; the message includes the
-- offending variable.
scopeCheckExpr :: forall env t. SList Var env -> SExpr t -> Ex env t
scopeCheckExpr val = \case
  SEVar tag@(Var ty _)
    | Just idx <- find tag val -> EVar ext ty idx
    | otherwise ->
        -- Report which variable failed to resolve; without it the failure
        -- cannot be traced back to a location in the surface program.
        error $ "Variable out of scope in conversion from surface \
                \expression to De Bruijn expression: " ++ show tag
  SELet a b -> ELet ext (go a) (lambda val b)

  SEPair a b -> EPair ext (go a) (go b)
  SEFst e -> EFst ext (go e)
  SESnd e -> ESnd ext (go e)
  SENil -> ENil ext
  SEInl t e -> EInl ext t (go e)
  SEInr t e -> EInr ext t (go e)
  SECase e a b -> ECase ext (go e) (lambda val a) (lambda val b)

  SEBuild1 a b -> EBuild1 ext (go a) (lambda val b)
  SEBuild n a b -> EBuild ext n (go a) (lambda val b)
  SEFold1 a b -> EFold1 ext (lambda2 val a) (go b)
  SEUnit e -> EUnit ext (go e)

  SEConst t x -> EConst ext t x
  SEIdx0 e -> EIdx0 ext (go e)
  SEIdx1 a b -> EIdx1 ext (go a) (go b)
  SEIdx n a b -> EIdx ext n (go a) (go b)
  SEShape e -> EShape ext (go e)
  SEOp op e -> EOp ext op (go e)

  SEError t s -> EError t s
  where
    -- Convert a subexpression in the same environment.
    go :: SExpr t' -> Ex env t'
    go = scopeCheckExpr val

    -- Find the first entry with the same name AND the same type as the given
    -- variable, counting positions from the head of the list.
    -- NOTE(review): an entry with the same name but a different type is
    -- skipped rather than treated as shadowing, so the search continues into
    -- outer bindings. Confirm that this is the intended scoping rule.
    find :: Var t' -> SList Var env' -> Maybe (Idx env' t')
    find _ SNil = Nothing
    find tag@(Var ty s) (Var ty' s' `SCons` val')
      | s == s'
      , Just Refl <- testEquality ty ty'
      = Just IZ
      | otherwise
      = IS <$> find tag val'

    -- Convert the body of a one-variable binder with that variable pushed
    -- onto the front of the environment.
    lambda :: SList Var env' -> Lambda a (SExpr b) -> Ex (a : env') b
    lambda val' (Lambda tag e) = scopeCheckExpr (tag `SCons` val') e

    -- Convert the body of a two-variable binder.
    -- NOTE(review): the OUTER variable is placed at the head of the
    -- environment (index IZ) and the inner one behind it, so if both
    -- variables share a name and a type, references resolve to the outer
    -- one. Confirm that this order matches what 'EFold1' expects.
    lambda2 :: SList Var env' -> Lambda a (Lambda b (SExpr c)) -> Ex (a : b : env') c
    lambda2 val' (Lambda tag (Lambda tag' e)) = scopeCheckExpr (tag `SCons` tag' `SCons` val') e