blob: b65e643618acba7a53179b3e5e4a02d1717df690 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
|
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
module Compile where
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import AST
import Data
-- In shape and index arrays, the innermost dimension is on the right (last index).
-- | A C @typedef struct@ to be emitted into the generated code.
data StructDecl
  = StructDecl
      String  -- ^ the typedef name (as produced by 'genStructName')
      String  -- ^ the member declarations, spliced verbatim between the braces
-- | Render a struct declaration as a single-line C @typedef@ followed by a
-- newline, e.g. @typedef struct { int32_t a; }ty_x;@. Note that no space is
-- emitted between the closing brace and the typedef name.
printStructDecl :: StructDecl -> ShowS
printStructDecl (StructDecl name contents) =
  showString (concat ["typedef struct { ", contents, " }", name, ";\n"])
-- | The C type used to store a value of the given type. Scalars map directly
-- to a C primitive type; every other type is stored in the typedef'd struct
-- whose name is given by 'genStructName'.
repTy :: STy t -> String
repTy ty = case ty of
  STScal STI32 -> "int32_t"
  STScal STI64 -> "int64_t"
  STScal STF32 -> "float"
  STScal STF64 -> "double"
  STScal STBool -> "bool"
  _ -> genStructName ty
-- | The C typedef name for a type: @ty_@ followed by a compact encoding of the
-- type's structure.
--
-- The result must be a valid C identifier, because it is used as a typedef
-- name. Every encoding starts with a letter, and array ranks are written in
-- decimal. The digits of a rank therefore always end at the next letter, and
-- the encoding stays unambiguous without any delimiters.
genStructName :: STy t -> String
genStructName = \ty -> "ty_" ++ gen ty
  where
    gen :: STy t -> String
    gen STNil = "n"
    gen (STPair a b) = 'p' : gen a ++ gen b
    gen (STEither a b) = 'e' : gen a ++ gen b
    gen (STMaybe t) = 'm' : gen t
    -- Previously "A[" ++ rank ++ "]": brackets are not allowed in C identifiers.
    gen (STArr n t) = 'A' : show (fromSNat n) ++ gen t
    gen (STScal st) = case st of
      STI32 -> "i4"
      STI64 -> "i8"
      STF32 -> "f4"
      STF64 -> "f8"
      STBool -> "b"
    gen (STAccum t) = 'C' : gen t
-- | Collect the C struct declarations needed to represent a value of the
-- given type, keyed by typedef name.
--
-- Field types are spelled with 'repTy', so scalar components use C primitive
-- types. Scalars get no struct of their own. The structs of component types
-- are collected recursively, so every non-scalar type that is referenced gets
-- declared. Duplicate names collapse in the 'Map'.
--
-- NOTE(review): the 'Map' is ordered by name, not by dependency. A by-value
-- struct field requires its component struct to be defined first, so whoever
-- prints these declarations probably needs a dependency-ordered traversal.
-- Verify against the printer.
-- NOTE(review): 'STNil' yields an empty struct. ISO C requires at least one
-- member; GCC accepts empty structs as an extension.
genStruct :: STy t -> Map String StructDecl
genStruct ty = case ty of
  STNil -> declare ""
  STPair a b ->
    declare (repTy a ++ " a; " ++ repTy b ++ " b;")
      <> genStruct a <> genStruct b
  STEither a b ->
    declare ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };")
      <> genStruct a <> genStruct b
  STMaybe a ->
    declare ("uint8_t tag; " ++ repTy a ++ " a;")
      <> genStruct a
  STArr n a ->
    declare ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ repTy a ++ " *a;")
      <> genStruct a
  STScal _ -> mempty
  STAccum a ->
    declare (repTy a ++ " a;")
      <> genStruct a
  where
    -- Singleton map with the declaration of 'ty' itself, given its members.
    declare contents =
      let name = genStructName ty
      in Map.singleton name (StructDecl name contents)
-- | Collect every C struct declaration needed for the values that flow
-- through an expression.
--
-- Every node contributes the structs for its own type ('typeOf'), and every
-- sub-expression is visited. This covers constructed values, variables,
-- projections, and array-producing primitives alike. The second component of
-- the result is a placeholder for the generated code, which is not produced
-- yet.
compile :: Ex env t -> (Map String StructDecl, ())
compile = \e -> (compileStructs e, ())

-- | Struct declarations for an expression node and all of its sub-expressions.
compileStructs :: Ex env t -> Map String StructDecl
compileStructs e = genStruct (typeOf e) <> childStructs e

-- | Struct declarations for the direct sub-expressions of a node, excluding
-- the node's own type. Non-expression fields (types, naturals, constants) are
-- skipped.
childStructs :: Ex env t -> Map String StructDecl
childStructs = \case
  EVar _ _ _ -> mempty
  ELet _ rhs body -> compileStructs rhs <> compileStructs body
  EPair _ a b -> compileStructs a <> compileStructs b
  EFst _ e -> compileStructs e
  ESnd _ e -> compileStructs e
  ENil _ -> mempty
  EInl _ _ e -> compileStructs e
  EInr _ _ e -> compileStructs e
  ECase _ e a b -> compileStructs e <> compileStructs a <> compileStructs b
  ENothing _ _ -> mempty
  EJust _ e -> compileStructs e
  EMaybe _ a b e -> compileStructs a <> compileStructs b <> compileStructs e
  EConstArr _ _ _ _ -> mempty
  EBuild _ _ a b -> compileStructs a <> compileStructs b
  EFold1Inner _ a b c -> compileStructs a <> compileStructs b <> compileStructs c
  ESum1Inner _ e -> compileStructs e
  EUnit _ e -> compileStructs e
  EReplicate1Inner _ a b -> compileStructs a <> compileStructs b
  EMaximum1Inner _ e -> compileStructs e
  EMinimum1Inner _ e -> compileStructs e
  EConst _ _ _ -> mempty
  EIdx0 _ e -> compileStructs e
  EIdx1 _ a b -> compileStructs a <> compileStructs b
  EIdx _ a b -> compileStructs a <> compileStructs b
  EShape _ e -> compileStructs e
  EOp _ _ e -> compileStructs e
  ECustom _ _ _ _ a b c e1 e2 ->
    compileStructs a <> compileStructs b <> compileStructs c
      <> compileStructs e1 <> compileStructs e2
  EWith a b -> compileStructs a <> compileStructs b
  EAccum _ a b e -> compileStructs a <> compileStructs b <> compileStructs e
  EZero _ -> mempty
  EPlus _ a b -> compileStructs a <> compileStructs b
  EOneHot _ _ a b -> compileStructs a <> compileStructs b
  EError _ _ -> mempty
-- | Chain a collection of endomorphisms right to left, so that
-- @compose [f, g, h] x == f (g (h x))@. An empty collection yields 'id'.
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (\f rest -> \x -> f (rest x)) id
|