summaryrefslogtreecommitdiff
path: root/interpreter.hs
blob: 45950354cbfda0f41d036d76d6a48d62848e0e57 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
{-# LANGUAGE TupleSections, DeriveFunctor, GeneralizedNewtypeDeriving #-}
module Interpreter(newContext, interpret, interpretProgram, Context) where

import Control.Applicative
import qualified Control.Exception as E
import Control.Monad
import Control.Monad.Except
import Control.Monad.State
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
-- import Debug.Trace

import AST
import Parser


-- TODO: keep the bottom (global) VarMap separately for performance? A global
-- 'define' issued from a deeply nested context currently walks the whole stack.
-- | One scope frame of variable bindings, keyed by name.
type VarMap = Map.Map Name Value
-- | Interpreter state: a stack of scopes, innermost first. Lookups search from
-- the head; the last element is the global scope that starts out holding the
-- builtins.
data Context = Context {varMapStk :: [VarMap]}

-- | The interpreter monad: 'Context' state layered over string errors over IO.
-- Because the state sits above 'ExceptT', a thrown error discards the state;
-- 'interpret' then returns only the error message.
newtype IM a = IM {unIM :: StateT Context (ExceptT String IO) a}
  deriving (Functor, Applicative, Monad, MonadError String, MonadState Context, MonadIO)

-- | A builtin receives its argument expressions unevaluated; each builtin
-- decides whether (and when) to evaluate them.
type Builtin = [Value] -> IM Value

-- | The initial interpreter state: a single global scope in which every
-- builtin name is bound to a 'VBuiltin' handle carrying that same name.
newContext :: Context
newContext = Context [Map.mapWithKey (\builtinName _ -> VBuiltin builtinName) builtins]

-- | Look a name up through the scope stack, innermost scope first. The first
-- binding found wins; 'Nothing' means no scope binds the name.
lookupVar :: Name -> IM (Maybe Value)
lookupVar name = gets $ \ctx -> listToMaybe (mapMaybe (Map.lookup name) (varMapStk ctx))

-- | Run an action with an extra scope pushed on top of the stack, then pop
-- that scope again and return the action's result. If the action throws, the
-- pop is skipped, but the error also discards the whole 'IM' state.
withScopeMap :: VarMap -> IM a -> IM a
withScopeMap vm act = modify push *> act <* modify pop
  where
    push ctx = ctx {varMapStk = vm : varMapStk ctx}
    pop ctx = ctx {varMapStk = tail (varMapStk ctx)}

-- | Every builtin, keyed by the name it is bound to in the global scope.
builtins :: Map.Map String Builtin
builtins = Map.fromList
    -- I/O and evaluation
    [ ("read", readBuiltin)
    , ("eval", evalBuiltin)
    , ("print", printBuiltin)
    , ("getline", getlineBuiltin)
    -- control flow and binding forms
    , ("loop", loopBuiltin)
    , ("do", doBuiltin)
    , ("if", ifBuiltin)
    , ("define", defineBuiltin)
    , ("lambda", lambdaBuiltin)
    , ("match", matchBuiltin)
    -- arithmetic (the trailing Int is the result for zero arguments)
    , ("+", plusBuiltin)
    , ("-", arithBuiltin "-" (-) 0)
    , ("*", arithBuiltin "*" (*) 1)
    , ("/", arithBuiltin "/" div 1)
    , ("%", arithBuiltin "%" mod 1)
    -- chained comparisons (numeric and string variants)
    , ("<", compareBuiltin "<" (<) (<))
    , ("<=", compareBuiltin "<=" (<=) (<=))
    , (">", compareBuiltin ">" (>) (>))
    , (">=", compareBuiltin ">=" (>=) (>=))
    , ("=", compareBuiltin "=" (==) (==))
    , ("!=", neqBuiltin)
    ]

-- | Wrap a builtin with an exact-arity check. When @doeval@ is set, the
-- arguments are evaluated left to right before being handed to @f@;
-- otherwise @f@ receives the raw argument expressions.
nArguments :: String -> Int -> Bool -> Builtin -> Builtin
nArguments name n doeval f args =
    if got /= n
        then throwError $ "Function '" ++ name ++ "' expects " ++ plural n ++ " but got " ++ plural got
        else if doeval then mapM evalValue args >>= f else f args
  where
    got = length args
    plural k = case k of
        0 -> "no arguments"
        1 -> "1 argument"
        _ -> show k ++ " arguments"


-- | @(read s)@: parse the string @s@ as a single expression and return it
-- without evaluating it. Parse errors are reported through their 'show' text.
readBuiltin :: Builtin
readBuiltin = nArguments "read" 1 True parseArg
  where
    parseArg :: Builtin
    parseArg [VString src] = case parseExpression src of
        Left err -> throwError (show err)
        Right expr -> return expr
    parseArg _ = throwError "Can only 'read' a string"

-- | @(eval e)@: evaluate @e@, then evaluate the resulting value once more.
-- For example, a quoted expression or the output of @read@ gets executed.
evalBuiltin :: Builtin
evalBuiltin = nArguments "eval" 1 True (\exprs -> evalValue (head exprs))

-- | @(print a b ...)@: evaluate every argument and write them on one line,
-- separated by single spaces. Strings are printed raw; other values use
-- 'show'. The result is the empty list.
printBuiltin :: Builtin
printBuiltin args = do
    vals <- mapM evalValue args
    liftIO . putStrLn . unwords $ map render vals
    return (VList [])
  where
    render :: Value -> String
    render v = case v of
        VString s -> s
        _ -> show v

-- | @(loop body)@: evaluate the unevaluated @body@ repeatedly, forever. The
-- loop is only left when the body throws (an interpreter error or an
-- exception).
loopBuiltin :: Builtin
loopBuiltin = nArguments "loop" 1 False repeatBody
  where
    repeatBody exprs = let body = head exprs in forever (evalValue body)

-- | @(getline)@: read one line from standard input and return it as a string.
--
-- At end of input (or on any other I/O failure) 'getLine' throws an
-- 'E.IOException'. Left alone, that exception escapes 'IM' and aborts the
-- whole interpreter, so it is turned into an ordinary interpreter error here.
getlineBuiltin :: Builtin
getlineBuiltin = nArguments "getline" 0 True $ const $ do
    res <- liftIO (E.try getLine)
    case res of
        Left err -> throwError $ "getline: " ++ show (err :: E.IOException)
        Right line -> return (VString line)

-- | @(do e1 e2 ...)@: evaluate the expressions in order and return the value
-- of the last one, or the empty list when there are none.
doBuiltin :: Builtin
doBuiltin = foldM (\_ expr -> evalValue expr) (VList [])

-- | @(if c t)@ / @(if c t e)@: evaluate @c@, then evaluate only the selected
-- branch. A missing else-branch yields the empty list.
ifBuiltin :: Builtin
ifBuiltin args = case args of
    [cond, thenE] -> branch cond thenE (return (VList []))
    [cond, thenE, elseE] -> branch cond thenE (evalValue elseE)
    _ -> throwError $ "Cannot pass " ++ show (length args) ++ " arguments to 'if'"
  where
    branch c t otherwiseAct = do
        c' <- evalValue c
        if truthy c' then evalValue t else otherwiseAct

-- | @(define name expr)@: evaluate @expr@ and bind it in the innermost scope
-- that already binds @name@. If no scope does, bind it in the bottom (global)
-- scope.
--
-- @(define name (a b ...) body)@ is sugar for binding @(lambda (a b ...) body)@.
-- Returns the empty list.
defineBuiltin :: Builtin
defineBuiltin [VName name, val] = do
    val' <- evalValue val
    stk <- gets varMapStk
    -- The stack always has at least the global scope; fail loudly rather than
    -- crash with 'undefined' if that invariant is ever broken.
    when (null stk) $ throwError "Internal error: 'define' with an empty scope stack"
    let rebind [vm] = [Map.insert name val' vm]          -- global scope: always bind here
        rebind (vm : vms)
            | Map.member name vm = Map.insert name val' vm : vms
            | otherwise = vm : rebind vms
        rebind [] = []                                   -- unreachable (guarded above)
    modify $ \ctx -> ctx {varMapStk = rebind stk}
    return (VList [])
defineBuiltin [name@(VName _), VList args, val]
    | Just names <- mapM fromVName args = defineBuiltin [name, VLambda names val]
    | otherwise = throwError "Invalid 'define' syntax: invalid argument list"
defineBuiltin _ = throwError "Invalid 'define' syntax"

-- | @(lambda (x y ...) body)@: build a 'VLambda' holding the parameter names
-- and the unevaluated body. No environment is captured. Every parameter must
-- be a plain name.
lambdaBuiltin :: Builtin
lambdaBuiltin = nArguments "lambda" 2 False build
  where
    build :: Builtin
    build [VList params, body] =
        case mapM fromVName params of
            Just names -> return (VLambda names body)
            Nothing -> throwError "Invalid 'lambda' syntax: invalid argument list"
    build _ = throwError "Invalid 'lambda' syntax"

-- | @(match subject (pat1 e1) (pat2 e2) ... default)@: evaluate @subject@,
-- then try the @(pattern expr)@ arms in order. The first arm whose pattern
-- matches has its expression evaluated in a new scope holding the pattern's
-- bindings. The final element is always treated as the default expression.
matchBuiltin :: Builtin
matchBuiltin args = case args of
    [] -> throwError "Invalid 'match' syntax: empty match"
    [_] -> throwError "Invalid 'match' syntax: no arms"
    subject : arms -> evalValue subject >>= \s -> tryArms s arms
  where
    tryArms :: Value -> [Value] -> IM Value
    tryArms _ [fallback] = evalValue fallback
    tryArms s (VList [pat, body] : more) =
        maybe (tryArms s more)
              (\bindings -> withScopeMap bindings (evalValue body))
              (match pat s Map.empty)
    tryArms _ _ = throwError "Invalid 'match' syntax: invalid arm"

-- | @(+ a b ...)@: evaluate the arguments, then add them if they are all
-- numbers. Otherwise, if each is a string or a number, concatenate their
-- string forms (numbers via 'show'). No arguments yield 0.
--
-- The arguments are evaluated first, like every other operator. Previously
-- they were inspected raw, so @(+ x 1)@ or @(+ (* 2 3) 1)@ was rejected.
plusBuiltin :: Builtin
plusBuiltin args = mapM evalValue args >>= add
  where
    add :: [Value] -> IM Value
    add [] = return (VNum 0)
    add vals
        | Just nums <- mapM fromVNum vals = return (VNum (sum nums))
        | Just strs <- mapM maybeStrings vals = return (VString (concat strs))
        | otherwise = throwError "Invalid argument types to operator '+'"

-- | Left-fold a binary 'Int' operator over the evaluated arguments.
--
-- @(- a b c)@ means @(a - b) - c@. A single argument is returned unchanged,
-- and no arguments yield @idelem@.
--
-- Arithmetic exceptions raised by @oper@ (e.g. division or modulo by zero)
-- are reported as interpreter errors. Previously they escaped 'IM' as Haskell
-- exceptions and crashed the interpreter.
arithBuiltin :: String -> (Int -> Int -> Int) -> Int -> Builtin
arithBuiltin name oper idelem args = do
    args' <- mapM evalValue args
    case mapM fromVNum args' of
        Just [] -> return (VNum idelem)
        Just (hd : tl) -> do
            -- Force the (strict) fold inside IO so a DivideByZero can be caught.
            res <- liftIO $ E.try $ E.evaluate $ foldl' oper hd tl
            case res of
                Left err -> throwError $ "Arithmetic error in operator '" ++ name ++ "': " ++ show (err :: E.ArithException)
                Right n -> return (VNum n)
        Nothing -> throwError $ "Invalid argument types to operator '" ++ name ++ "'"

-- | @(!= a b ...)@: the negation of @(= a b ...)@. It yields 0 when all
-- arguments are equal and 1 otherwise.
neqBuiltin :: Builtin
neqBuiltin args = invert <$> compareBuiltin "!=" (==) (==) args
  where
    -- compareBuiltin only ever produces VNum 0 or VNum 1.
    invert (VNum x) = VNum (1 - x)

-- | Chained comparison: evaluate the arguments and check that @op@ holds for
-- every adjacent pair. Numbers are compared numerically when all arguments
-- are numbers. Otherwise they are compared as strings (numbers via 'show').
-- Returns 1 for true and 0 for false; fewer than two arguments give 1.
compareBuiltin :: String -> (Int -> Int -> Bool) -> (String -> String -> Bool) -> Builtin
compareBuiltin name oper soper args = do
    vals <- mapM evalValue args
    ok <- case (mapM fromVNum vals, mapM maybeStrings vals) of
        (Just nums, _) -> return (chained oper nums)
        (Nothing, Just strs) -> return (chained soper strs)
        (Nothing, Nothing) -> throwError $ "Invalid argument types to operator '" ++ name ++ "'"
    return (VNum (if ok then 1 else 0))
  where
    chained :: (a -> a -> Bool) -> [a] -> Bool
    chained op xs = and (zipWith op xs (drop 1 xs))


-- | Truthiness: the number 0 is false; every other value is true.
truthy :: Value -> Bool
truthy v = case v of
    VNum 0 -> False
    _ -> True

-- | Structurally match a pattern (an unevaluated expression) against a value.
-- The bindings collected so far are threaded through; the result is the
-- extended map on success, or 'Nothing'.
--
-- * An empty list pattern matches only an empty list.
-- * A list pattern whose remaining part is exactly @['VEllipsis']@ matches
--   any remaining tail, including an empty one. An ellipsis anywhere else is
--   only compared with '=='.
-- * A name binds the value. If the name was already bound earlier in the
--   same pattern, the value must equal that earlier binding.
-- * Quoted patterns match quoted values by matching their contents.
-- * A lambda pattern never matches.
-- * Anything else matches only an equal value.
--
-- Clause order matters: the ellipsis clause must precede the cons clause, and
-- the final '==' clause is the catch-all.
-- NOTE(review): there is no wildcard; a repeated name (e.g. @_@) must match
-- equal values each time it appears.
match :: Value -> Value -> VarMap -> Maybe VarMap
match (VList []) (VList []) mp = Just mp
match (VList [VEllipsis]) (VList _) mp = Just mp
match (VList (pat : pats)) (VList (val : vals)) mp = match pat val mp >>= match (VList pats) (VList vals)
match (VName name) val mp = case Map.lookup name mp of
    Nothing -> Just (Map.insert name val mp)
    Just val' | val == val' -> Just mp
              | otherwise -> Nothing
match (VQuoted a) (VQuoted b) mp = match a b mp
match (VLambda _ _) _ _ = Nothing
match a b mp | a == b = Just mp
             | otherwise = Nothing

-- | View a value as a string for string-flavoured operators. Strings are
-- returned as-is and numbers via 'show'; any other value gives 'Nothing'.
maybeStrings :: Value -> Maybe String
maybeStrings v = fromVString v <|> fmap show (fromVNum v)


-- | Evaluate an expression against the current scope stack.
--
-- * A list is a call.
-- * A name is looked up, innermost scope first.
-- * A quote yields its contents unevaluated.
-- * Numbers, strings, lambdas and builtins evaluate to themselves.
-- * A bare ellipsis is an error.
evalValue :: Value -> IM Value
-- evalValue v | traceShow v False = undefined
evalValue expr = case expr of
    VList exs -> listCall exs
    VName name -> lookupVar name >>= maybe (throwError $ "Use of undefined variable '" ++ name ++ "'") return
    VQuoted inner -> return inner
    VEllipsis -> throwError "Unexpected ellipsis in code"
    VNum _ -> return expr
    VString _ -> return expr
    VLambda _ _ -> return expr
    VBuiltin _ -> return expr

-- | Perform a call: evaluate the head, then dispatch on what it produced.
--
-- * A lambda needs exactly as many arguments as it has parameters. The
--   arguments are evaluated in the caller's scope. The body then runs with a
--   new parameter scope pushed on top of the caller's stack, so scoping is
--   dynamic.
-- * A builtin receives its argument expressions unevaluated.
listCall :: [Value] -> IM Value
listCall [] = throwError "Cannot call ()"
listCall (hd : args) = do
    callee <- evalValue hd
    case callee of
        VLambda names body -> callLambda names body
        VBuiltin name -> maybe (throwError $ "Unknown builtin '" ++ name ++ "'") ($ args) (Map.lookup name builtins)
        v -> throwError $ "Cannot call value: " ++ show v
  where
    callLambda :: [Name] -> Value -> IM Value
    callLambda names body
        | nParams /= nArgs = throwError $ "Invalid number of arguments in call to lambda " ++
                                          "(" ++ show nArgs ++ " found, " ++ show nParams ++ " needed)"
        | otherwise = do
            vals <- mapM evalValue args
            withScopeMap (Map.fromList (zip names vals)) (evalValue body)
      where
        nParams = length names
        nArgs = length args


-- | Evaluate one top-level expression against a context. Success gives the
-- value together with the updated context. Failure gives only the error
-- message; state changes made before the error are dropped.
interpret :: Context -> Value -> IO (Either String (Value, Context))
interpret ctx val = runExceptT (runStateT (unIM (evalValue val)) ctx)

-- | Run the top-level expressions of a program in order, threading the
-- context from each into the next. Execution stops at the first error, and
-- that error is returned; later expressions are never run.
interpretProgram :: Context -> Program -> IO (Either String Context)
interpretProgram rootctx (Program l) = foldM step (Right rootctx) l
  where
    step :: Either String Context -> Value -> IO (Either String Context)
    step (Left err) _ = return (Left err)
    step (Right ctx) val = fmap snd <$> interpret ctx val