summaryrefslogtreecommitdiff
path: root/src/Language/AST.hs
blob: 1c53c8a94b73a4399d2e9dfa25ac223338ef2c18 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Language.AST where

import AST
import Data
import Data.Type.Equality
import Language.Tag


-- | Surface expressions. Variables are referred to by 'Tag' (see 'SEVar')
-- rather than by De Bruijn index; the type index @t@ is the type of the
-- expression. 'scopeCheckExpr' converts these into the De Bruijn-indexed 'Ex'.
data SExpr t where
  -- lambda calculus
  -- | Reference to the variable identified by the tag.
  SEVar :: Tag t -> SExpr t
  -- | Let-binding: the tag of the 'Lambda' names a variable bound to the first
  -- argument, in scope within the lambda's body.
  SELet :: SExpr a -> Lambda a (SExpr t) -> SExpr t

  -- base types
  -- | Pair construction.
  SEPair :: SExpr a -> SExpr b -> SExpr (TPair a b)
  -- | First projection of a pair.
  SEFst :: SExpr (TPair a b) -> SExpr a
  -- | Second projection of a pair.
  SESnd :: SExpr (TPair a b) -> SExpr b
  -- | The nullary 'TNil' value.
  SENil :: SExpr TNil
  -- | Left injection. The 'STy' gives the type of the right alternative,
  -- which cannot be recovered from the argument.
  SEInl :: STy b -> SExpr a -> SExpr (TEither a b)
  -- | Right injection. The 'STy' gives the type of the left alternative.
  SEInr :: STy a -> SExpr b -> SExpr (TEither a b)
  -- | Case analysis on a 'TEither'; each 'Lambda' binds the payload of its
  -- alternative.
  SECase :: SExpr (TEither a b) -> Lambda a (SExpr c) -> Lambda b (SExpr c) -> SExpr c

  -- array operations
  -- | Build a rank-1 array; the body 'Lambda' binds the element index.
  -- NOTE(review): the 'TIx' argument is presumably the length — confirm.
  SEBuild1 :: SExpr TIx -> Lambda TIx (SExpr t) -> SExpr (TArr (S Z) t)
  -- | Build a rank-@n@ array. Both the second argument (presumably the shape)
  -- and the index bound by the body 'Lambda' are @n@-tuples of 'TIx'.
  SEBuild :: SNat n -> SExpr (Tup (Replicate n TIx)) -> Lambda (Tup (Replicate n TIx)) (SExpr t) -> SExpr (TArr n t)
  -- | Reduce a rank-@S n@ array to rank @n@ with a binary combining function,
  -- given as two nested 'Lambda's binding the two operands.
  SEFold1 :: Lambda t (Lambda t (SExpr t)) -> SExpr (TArr (S n) t) -> SExpr (TArr n t)
  -- | Wrap a value as a rank-0 array.
  SEUnit :: SExpr t -> SExpr (TArr Z t)

  -- expression operations
  -- | Scalar constant. The 'Show' constraint is stored so that the derived
  -- 'Show' instance below can print the value.
  SEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> SExpr (TScal t)
  -- | Read the element of a rank-0 array.
  SEIdx0 :: SExpr (TArr Z t) -> SExpr t
  -- | Index a rank-@S n@ array with a single 'TIx', yielding a rank-@n@ array.
  SEIdx1 :: SExpr (TArr (S n) t) -> SExpr TIx -> SExpr (TArr n t)
  -- | Index a rank-@n@ array with an @n@-tuple of indices, yielding an element.
  SEIdx :: SNat n -> SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -> SExpr t
  -- | Shape of a rank-@n@ array, as an @n@-tuple of 'TIx'.
  SEShape :: SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx))
  -- | Apply a primitive operation. Binary arithmetic takes its operands
  -- packed in an 'SEPair' (as done by the 'Num' instance).
  SEOp :: SOp a t -> SExpr a -> SExpr t

  -- partiality
  -- | Runtime error with the given message; the 'STy' records the result type.
  SEError :: STy a -> String -> SExpr a
-- GADT constructors with refined indices need standalone deriving for 'Show'.
deriving instance Show (SExpr t)

-- | A binder: the 'Tag' names the bound variable (of type @a@), and the second
-- field is the body in which that variable is in scope. The body may itself
-- be a 'Lambda' to bind several variables (as built by 'mkLambda2'').
data Lambda a b = Lambda (Tag a) b
  deriving (Show)

-- | Build a 'Lambda' whose bound variable's type comes from 'KnownTy'.
-- Delegates to 'mkLambda'', which generates the tag from @handle@ and passes
-- the bound variable to the callback as an 'SEVar'.
mkLambda :: KnownTy a => handle -> (SExpr a -> f t) -> Lambda a (f t)
mkLambda handle = mkLambda' handle knownTy

-- | Build a 'Lambda' for a variable of the explicitly supplied type. One tag
-- is generated from @handle@ and shared by the binder and by the 'SEVar'
-- handed to the body-building callback.
mkLambda' :: handle -> STy a -> (SExpr a -> f t) -> Lambda a (f t)
mkLambda' handle ty f = Lambda boundTag (f (SEVar boundTag))
  where
    -- Bound once so the binder and its uses refer to the very same tag.
    boundTag = genTag handle ty

-- | Two-binder variant of 'mkLambda': both variable types come from 'KnownTy'.
-- Delegates to 'mkLambda2'', which generates the tags and nests the binders.
mkLambda2 :: (KnownTy a, KnownTy b)
          => handle -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t))
mkLambda2 handle = mkLambda2' handle knownTy knownTy

-- | Build two nested binders. The outer 'Lambda' binds a variable of type @a@
-- (tagged @tag1@) and the inner one a variable of type @b@ (tagged @tag2@).
-- The callback receives both variables as 'SEVar's, outer one first.
--
-- The @let@ bindings are lazy and mutually referring: @lam2@ mentions @tag1@,
-- and @tag1@ is generated with @lam2@ (not @handle@) as its handle.
-- NOTE(review): presumably this data dependency is what keeps the two tags
-- distinct and ordered, e.g. it stops GHC from sharing two identical
-- @genTag handle ty@ calls when @a ~ b@. That depends on how 'genTag' is
-- implemented — confirm before restructuring these bindings.
mkLambda2' :: handle -> STy a -> STy b -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t))
mkLambda2' handle ty1 ty2 f =
  let tag2 = genTag handle ty2
      lam2 = Lambda tag2 (f (SEVar tag1) (SEVar tag2))
      tag1 = genTag lam2 ty1
  in Lambda tag1 lam2

-- | Arithmetic on scalar surface expressions. Addition, multiplication and
-- negation become primitive 'SEOp' nodes; binary operands are packed in an
-- 'SEPair'. Integer literals become 'SEConst' nodes. 'abs' and 'signum' have
-- no corresponding operation here and raise an error.
instance (t ~ TScal st, KnownScalTy st, Num (ScalRep st)) => Num (SExpr t) where
  (+) lhs rhs = SEOp (OAdd knownScalTy) (SEPair lhs rhs)
  (*) lhs rhs = SEOp (OMul knownScalTy) (SEPair lhs rhs)
  negate = SEOp (ONeg knownScalTy)
  abs = error "abs undefined for SExpr"
  signum = error "signum undefined for SExpr"
  -- Matching on 'Dict' brings the 'Show (ScalRep st)' constraint that
  -- 'SEConst' requires into scope.
  fromInteger = case scalRepIsShow scalTy of
      Dict -> SEConst scalTy . fromInteger
    where
      -- No signature: the monomorphism restriction keeps this at the single
      -- scalar type of the instance, so both uses above agree.
      scalTy = knownScalTy

-- | A surface-level function: the tags of its arguments and its body. In
-- 'scopeCheck' the argument tags form the initial scope, and the tag at the
-- head of the list becomes De Bruijn index 'IZ'.
data SFun args t = SFun (SList Tag args) (SExpr t)

-- | Convert a surface function into a De Bruijn-indexed expression, using its
-- argument tags as the initial scope for the body.
scopeCheck :: SFun env t -> Ex env t
scopeCheck = \case
  SFun argTags body -> scopeCheckExpr argTags body

-- | Convert a tag-based surface expression into a De Bruijn-indexed 'Ex'.
--
-- The tag list is the current scope. The tag at its head becomes index 'IZ',
-- and each further position adds one 'IS'. Every 'Lambda' met during the
-- conversion pushes its tag onto the front of the scope before its body is
-- converted.
--
-- Calls 'error' (naming the offending tag) if an 'SEVar' refers to a tag that
-- is not in scope.
scopeCheckExpr :: forall env t. SList Tag env -> SExpr t -> Ex env t
scopeCheckExpr val = \case
  SEVar tag@(Tag ty _)
    | Just idx <- find tag val -> EVar ext ty idx
    -- Include the tag so the failing variable can actually be identified.
    | otherwise -> error $ "Variable out of scope in conversion from surface \
                           \expression to De Bruijn expression: " ++ show tag
  SELet a b -> ELet ext (go a) (lambda val b)

  SEPair a b -> EPair ext (go a) (go b)
  SEFst e -> EFst ext (go e)
  SESnd e -> ESnd ext (go e)
  SENil -> ENil ext
  SEInl t e -> EInl ext t (go e)
  SEInr t e -> EInr ext t (go e)
  SECase e a b -> ECase ext (go e) (lambda val a) (lambda val b)

  SEBuild1 a b -> EBuild1 ext (go a) (lambda val b)
  SEBuild n a b -> EBuild ext n (go a) (lambda val b)
  SEFold1 a b -> EFold1 ext (lambda2 val a) (go b)
  SEUnit e -> EUnit ext (go e)

  SEConst t x -> EConst ext t x
  SEIdx0 e -> EIdx0 ext (go e)
  SEIdx1 a b -> EIdx1 ext (go a) (go b)
  SEIdx n a b -> EIdx ext n (go a) (go b)
  SEShape e -> EShape ext (go e)
  SEOp op e -> EOp ext op (go e)

  SEError t s -> EError t s
  where
    -- Convert a subexpression in the unchanged scope.
    go :: SExpr t' -> Ex env t'
    go = scopeCheckExpr val

    -- Linear search for a tag. A match needs equal ids AND equal types (the
    -- type equality is what lets us return a well-typed 'Idx'). If the ids
    -- agree but the types do not, the search keeps going.
    find :: Tag t' -> SList Tag env' -> Maybe (Idx env' t')
    find _ SNil = Nothing
    find tag@(Tag ty i) (Tag ty' i' `SCons` val')
      | i == i'
      , Just Refl <- testEquality ty ty'
      = Just IZ
      | otherwise
      = IS <$> find tag val'

    -- Convert the body of a single binder, with its tag pushed onto the scope.
    lambda :: SList Tag env' -> Lambda a (SExpr b) -> Ex (a : env') b
    lambda val' (Lambda tag e) = scopeCheckExpr (tag `SCons` val') e

    -- Convert the body of two nested binders. The OUTER binder's tag is put at
    -- the head of the scope (index 'IZ') and the inner one's at 'IS' 'IZ'.
    -- NOTE(review): this is the reverse of the usual innermost-first De Bruijn
    -- order. 'SEFold1' binds two variables of the same type, so the type
    -- checker cannot detect a swapped order. Confirm that this order matches
    -- how 'EFold1' interprets its body (it matters for non-commutative
    -- combining functions).
    lambda2 :: SList Tag env' -> Lambda a (Lambda b (SExpr c)) -> Ex (a : b : env') c
    lambda2 val' (Lambda tag (Lambda tag' e)) = scopeCheckExpr (tag `SCons` tag' `SCons` val') e