summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
blob: 7ffb14cd2f568947a7d631e96691fe615ebef246 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module Interpreter (
  interpret,
  interpret',
  Value,
  NoAccum(..),
  unAccum,
) where

import Data.Int (Int64)
import Data.Proxy

import AST
import Data
import Array
import Interpreter.Accum
import Interpreter.Rep
import Control.Monad (foldM)


-- | Evaluate a closed expression whose result type contains no accumulators
-- and return its plain runtime value.
interpret :: NoAccum t => Ex '[] t -> Rep t
interpret expr = runAcM (evalClosed expr)
  where
    -- Evaluate under an arbitrary state token @s@. The 'NoAccum' proof
    -- rewrites @Rep' s t@ to @Rep t@, so the result type no longer mentions
    -- @s@.
    evalClosed :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t)
    evalClosed body = case noAccum (Proxy @s) (Proxy @t) of
      Refl -> interpret' SNil body

-- | A runtime value of object-language type @t@, as stored in the
-- interpreter's environment (an @SList (Value s) env@). The newtype lets
-- @Rep' s@ be used where a partially-applied type constructor is expected
-- (presumably because @Rep'@ is a type family and cannot be partially
-- applied directly — confirm).
newtype Value s t = Value (Rep' s t)

-- | Evaluate an expression in an environment of already-computed values,
-- inside the accumulator monad @AcM s@.
--
-- Variables are looked up positionally in @env@. Every binder pushes its
-- value onto the head of the list before evaluating its body: 'ELet',
-- both 'ECase' branches, the element index of 'EBuild1'/'EBuild', the two
-- operands of 'EFold1Inner', and the accumulator of 'EWith'.
interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t)
interpret' env = \case
  -- Positional lookup; the innermost binding is at the head of env.
  EVar _ _ i -> case slistIdx env i of Value x -> return x
  ELet _ a b -> do
    x <- interpret' env a
    interpret' (Value x `SCons` env) b
  EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
  EFst _ e -> fst <$> interpret' env e
  ESnd _ e -> snd <$> interpret' env e
  ENil _ -> return ()
  EInl _ _ e -> Left <$> interpret' env e
  EInr _ _ e -> Right <$> interpret' env e
  -- The payload of whichever side is present is bound for that branch.
  ECase _ e a b -> interpret' env e >>= \case
                     Left x -> interpret' (Value x `SCons` env) a
                     Right y -> interpret' (Value y `SCons` env) b
  EConstArr _ _ _ v -> return v
  -- Rank-1 build: 'a' is the length (Int64); 'b' computes each element with
  -- that element's Int64 index bound at the head of the environment.
  EBuild1 _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arrayGenerateLinM (ShNil `ShCons` n)
                      (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
  -- Rank-dim build: 'a' yields the shape as a tuple of Int64; 'b' computes
  -- each element with the element's index (as a tuple) bound at the head.
  EBuild _ dim a b -> do
    sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a
    arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b)
  -- Left fold over the innermost dimension using the binary function 'a'.
  -- The running accumulator x and the next element y are the two innermost
  -- variables, with y at the head. foldl1M raises an error if the innermost
  -- dimension is empty.
  EFold1Inner _ a b -> do
    let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
    arr <- interpret' env b
    let sh `ShCons` n = arrayShape arr
    arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  -- Sum over the innermost dimension. The scalar type is recovered from
  -- typeOf e so that numericIsNum can provide the Num instance. The lazy
  -- let-pattern assumes e has an array-of-scalars type (presumably enforced
  -- by the constructor's typing).
  ESum1Inner _ e -> do
    arr <- interpret' env e
    let STArr _ (STScal t) = typeOf e
        sh `ShCons` n = arrayShape arr
    numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  -- Wrap a single value in a rank-0 (ShNil-shaped) array.
  EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
  -- Add a new innermost dimension of length n. Every element along it is a
  -- copy of arr at the outer index; the innermost index component is ignored.
  EReplicate1Inner _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arr <- interpret' env b
    let sh = arrayShape arr
    arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx))
  EConst _ _ v -> return v
  -- Reads the element at linear position 0 (presumably e is a rank-0 array;
  -- confirm against AST).
  EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
  EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
  -- The Int64 index tuple from 'b' is converted to an Index before lookup.
  EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b)
  -- The rank n comes from the static type of e; the shape is returned as an
  -- Int64 tuple.
  EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e
  EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e
  -- Create an accumulator from e1's value (described by typeOf e1), then
  -- evaluate e2 with it bound at the head of the environment. The shape of
  -- the overall result is determined by withAccum (not visible here).
  EWith e1 e2 -> do
    initval <- interpret' env e1
    withAccum (typeOf e1) initval $ \accum ->
      interpret' (Value accum `SCons` env) e2
  -- Operands are evaluated in order e1, e2, e3. i is forwarded unchanged to
  -- accumAdd, which defines its meaning.
  EAccum i e1 e2 e3 -> do
    idx <- interpret' env e1
    val <- interpret' env e2
    accum <- interpret' env e3
    accumAdd accum i idx val
  -- Aborts evaluation via 'error', carrying the program's message.
  EError _ s -> error $ "Interpreter: Program threw error: " ++ s

-- | Apply a primitive operator to its already-evaluated argument.
--
-- Binary operators receive both operands as a single Haskell pair. 'OIf'
-- turns a boolean into a sum tag: 'Left' for 'True', 'Right' for 'False'.
interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t
interpretOp _ op arg = case op of
  -- arithmetic
  OAdd ty -> numericIsNum ty (fst arg + snd arg)
  OMul ty -> numericIsNum ty (fst arg * snd arg)
  ONeg ty -> numericIsNum ty (negate arg)
  -- comparisons
  OLt ty -> numericIsNum ty (fst arg < snd arg)
  OLe ty -> numericIsNum ty (fst arg <= snd arg)
  OEq ty -> numericIsNum ty (fst arg == snd arg)
  -- booleans
  ONot -> not arg
  OIf
    | arg -> Left ()
    | otherwise -> Right ()

-- | Run a computation that needs 'Num' and 'Ord' for the representation of a
-- numeric scalar type. Matching on the singleton fixes @st@ to a concrete
-- type (STI32, STI64, STF32 or STF64), so the instances are resolved at
-- that type. The @ScalIsNumeric st ~ True@ constraint limits callers to
-- numeric scalars.
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum ty k = case ty of
  STI32 -> k
  STI64 -> k
  STF32 -> k
  STF64 -> k

-- | Convert the interpreter's index tuple (the runtime form of
-- @Tup (Replicate n TIx)@) into an @f n@, built with the supplied nil/cons
-- (e.g. 'ShNil'/'ShCons' or 'IxNil'/'IxCons').
--
-- The tuple is left-nested: the top-level pair holds the tuple for the
-- first n-1 components plus the last component, and that last component is
-- consed on last. Each component is converted from 'Int64' to 'Int'.
unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m))
            -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n
unTupRepIdx _ nil _ SZ _ = nil
unTupRepIdx p nil cons (SS n) (rest, i) =
  let prefix = unTupRepIdx p nil cons n rest
  in cons prefix (fromIntegral @Int64 @Int i)

-- | Convert an @f n@ (e.g. an 'Index' or 'Shape') into the interpreter's
-- left-nested tuple of 'Int64' components. This is the opposite direction
-- to 'unTupRepIdx'. The supplied @uncons@ peels off the last component,
-- which becomes the right half of the top-level pair.
tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int))
          -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx))
tupRepIdx _ _ SZ _ = ()
tupRepIdx p uncons (SS n) tup =
  let split = uncons tup
  in (tupRepIdx p uncons n (fst split), fromIntegral @Int @Int64 (snd split))

-- | Split a non-empty index into its outer prefix and its innermost
-- component (the one added by 'IxCons').
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons ix = case ix of
  IxCons rest i -> (rest, i)

-- | Split a non-empty shape into its outer prefix and its innermost
-- dimension (the one added by 'ShCons').
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons sh = case sh of
  ShCons rest len -> (rest, len)

-- | Types whose runtime representation does not depend on the accumulator
-- state token @s@. For such types @Rep' s t@ and @Rep t@ are equal, as
-- witnessed by 'noAccum'. No instance is given for accumulator types
-- (compare 'unAccum', which returns 'Nothing' for 'STAccum').
class NoAccum t where
  noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t

-- Leaf types: the representations coincide directly.
instance NoAccum TNil where
  noAccum _ _ = Refl

instance NoAccum (TScal t) where
  noAccum _ _ = Refl

-- Compound types: the equality follows from the components' equalities.
instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where
  noAccum p _ = case (noAccum p (Proxy @a), noAccum p (Proxy @b)) of
    (Refl, Refl) -> Refl

instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where
  noAccum p _ = case (noAccum p (Proxy @a), noAccum p (Proxy @b)) of
    (Refl, Refl) -> Refl

instance NoAccum t => NoAccum (TArr n t) where
  noAccum p _ = case noAccum p (Proxy @t) of
    Refl -> Refl

-- | Decide at runtime whether a type is free of accumulators and, if so,
-- produce the 'NoAccum' evidence for it. The result is 'Nothing' as soon as
-- an 'STAccum' component is found. The proxy is only threaded through the
-- recursion.
unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t))
unAccum p = \case
  STNil -> Just Dict
  STScal{} -> Just Dict
  STAccum{} -> Nothing
  STPair t1 t2 -> do
    Dict <- unAccum p t1
    Dict <- unAccum p t2
    Just Dict
  STEither t1 t2 -> do
    Dict <- unAccum p t1
    Dict <- unAccum p t2
    Just Dict
  STArr _ t -> do
    Dict <- unAccum p t
    Just Dict

-- | Monadic left fold over a list, using the first element as the initial
-- accumulator. Calls 'error' when given an empty list.
foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
foldl1M step xs = case xs of
  [] -> error "foldl1M: empty list"
  seed : rest -> foldM step seed rest