summaryrefslogtreecommitdiff
path: root/src/Interpreter/Rep.hs
blob: f84f4e798991a9d47f6db87eb0db289209ac3586 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Interpreter.Rep where

import Data.List (intersperse, intercalate)
import Data.Foldable (toList)
import Data.IORef
import GHC.TypeError

import Array
import AST
import AST.Pretty
import Data


-- | Runtime (pure) Haskell representation of a value of object-language type
-- @t@. Pairs, sums and optionals map onto Haskell tuples, 'Either' and
-- 'Maybe'; arrays become an 'Array' of element representations; scalars are
-- delegated to 'ScalRep' and accumulators to 'RepAc'.
type family Rep t where
  Rep TNil          = ()
  Rep (TPair l r)   = (Rep l, Rep r)
  Rep (TEither l r) = Either (Rep l) (Rep r)
  Rep (TMaybe e)    = Maybe (Rep e)
  Rep (TArr n e)    = Array n (Rep e)
  Rep (TScal s)     = ScalRep s
  Rep (TAccum e)    = RepAc e

-- | Mutable accumulator representation of object-language type @t@;
-- represents D2 of @t@. Has an O(1) zero.
--
-- Pair, sum, optional and array layers are each held in an 'IORef'; for
-- arrays the empty array stands for zero. NOTE(review): the @Maybe@ layers
-- presumably use 'Nothing' as their zero — confirm against the accumulator
-- operations. Nested accumulators are rejected at compile time via
-- 'TypeError'.
type family RepAc t where
  RepAc TNil = ()
  RepAc (TPair a b) = IORef (Maybe (RepAc a, RepAc b))
  RepAc (TEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b)))
  RepAc (TMaybe t) = IORef (Maybe (RepAc t))
  -- TODO: an empty array is invalid for a zero-dimensional array, so zero-dimensional arrays don't actually have an O(1) zero.
  RepAc (TArr n t) = IORef (Array n (RepAc t))  -- empty array is zero
  RepAc (TScal sty) = RepAcScal sty
  -- The diagnostic names this family (RepAc) so the compile error points at
  -- the right definition.
  RepAc (TAccum t) = TypeError (Text "RepAc: Nested accumulators")

-- | Accumulator representation of a scalar type. Only the floating-point
-- scalars carry mutable state (an 'IORef'); integer and boolean scalars are
-- represented by @()@.
type family RepAcScal s where
  RepAcScal TF32  = IORef Float
  RepAcScal TF64  = IORef Double
  RepAcScal TI32  = ()
  RepAcScal TI64  = ()
  RepAcScal TBool = ()

-- | A runtime value of object-language type @t@, wrapping its 'Rep'.
newtype Value t = Value
  { unValue :: Rep t
  }

-- | Lift a function on representations to a function on 'Value's.
liftV :: (Rep a -> Rep b) -> Value a -> Value b
liftV f = Value . f . unValue

-- | Lift a binary function on representations to a function on 'Value's.
liftV2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c
liftV2 f va vb = Value (f (unValue va) (unValue vb))

-- | Combine two values into a value of the corresponding pair type.
vPair :: Value a -> Value b -> Value (TPair a b)
vPair (Value l) (Value r) = Value (l, r)

-- | Split a pair-typed value into its two components. The underlying tuple
-- is forced by the match.
vUnpair :: Value (TPair a b) -> (Value a, Value b)
vUnpair v = case unValue v of
  (l, r) -> (Value l, Value r)

-- | Render a runtime value of object-language type @t@ as a Haskell-like
-- expression, in the style of 'showsPrec'.
--
-- The 'Int' is the precedence of the enclosing context. Constructor
-- applications ('Left', 'Right', 'Just', @arrayFromList@) are parenthesised
-- when it exceeds 10. Scalars defer to their own 'showsPrec', so a negative
-- number under 'Just' renders as @Just (-1.0)@ rather than @Just -1.0@.
-- Accumulators are opaque and shown as a placeholder naming their type.
showValue :: Int -> STy t -> Rep t -> ShowS
showValue _ STNil () = showString "()"
showValue _ (STPair a b) (x, y) =
  showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x
showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y
showValue _ (STMaybe _) Nothing = showString "Nothing"
showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
showValue d (STArr _ t) arr = showParen (d > 10) $
  showString "arrayFromList " . showsPrec 11 (arrayShape arr)
  . showString " ["
  . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))
  . showString "]"
-- Matching on the scalar singleton fixes 'ScalRep sty' to a concrete type so
-- that its Show instance is available; the context precedence is passed on so
-- negative numbers get parenthesised where required.
showValue d (STScal sty) x = case sty of
  STF32 -> showsPrec d x
  STF64 -> showsPrec d x
  STI32 -> showsPrec d x
  STI64 -> showsPrec d x
  STBool -> showsPrec d x
showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSTy 0 t ++ ">"

-- | Render an environment as a bracketed, comma-separated list, showing each
-- value at its corresponding type (at precedence 0).
showEnv :: SList STy env -> SList Value env -> String
showEnv tys vals = '[' : intercalate ", " (render tys vals) ++ "]"
  where
    -- Walk the type list and value list in lockstep.
    render :: SList STy e -> SList Value e -> [String]
    render SNil SNil = []
    render (SCons ty rest) (SCons (Value v) vs) = showValue 0 ty v "" : render rest vs