-- path: root/src/Compile.hs
-- blob: b65e643618acba7a53179b3e5e4a02d1717df690
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
module Compile where

import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)

import AST
import Data


-- In shape and index arrays, the innermost dimension is on the right (last index).


-- | One C struct definition, kept as raw text fragments; 'printStructDecl'
-- turns it into a @typedef@.
data StructDecl
  = StructDecl
      String -- ^ name the struct is typedef'd to
      String -- ^ contents: the text placed between the struct's braces

-- | Render a struct declaration as a one-line C @typedef@ followed by a
-- newline, in the shape @typedef struct { \<contents\> } \<name\>;@.
--
-- The result is a 'ShowS', so several declarations can be chained with @(.)@.
-- A space separates the closing brace from the typedef name.
printStructDecl :: StructDecl -> ShowS
printStructDecl (StructDecl name contents) =
  showString "typedef struct { "
    . showString contents
    . showString " } "
    . showString name
    . showString ";\n"

-- | The C type name used to represent values of the given type.
--
-- Scalars map directly onto a fixed C type; every other type is represented
-- by the struct whose name 'genStructName' produces.
repTy :: STy t -> String
repTy = \case
  STScal scal ->
    case scal of
      STI32 -> "int32_t"
      STI64 -> "int64_t"
      STF32 -> "float"
      STF64 -> "double"
      STBool -> "bool"
  other -> genStructName other

-- | Mangle a type into the name of the C struct that represents it.
--
-- The name is @ty_@ followed by a prefix encoding of the type: one tag per
-- type constructor, followed by the encodings of its arguments. Array types
-- are encoded as @A@, then the decimal value of @fromSNat n@, then the
-- element type.
--
-- The result is used as a C typedef name, so it must consist only of
-- characters that are valid in a C identifier; in particular the array case
-- must not use brackets. Every tag begins with a letter, so the digits after
-- @A@ are delimited by the tag that follows them (assuming @fromSNat@ shows
-- as plain decimal digits -- TODO confirm).
genStructName :: STy t -> String
genStructName = \t -> "ty_" ++ gen t where
  gen :: STy t -> String
  gen STNil = "n"
  gen (STPair a b) = 'p' : gen a ++ gen b
  gen (STEither a b) = 'e' : gen a ++ gen b
  gen (STMaybe t) = 'm' : gen t
  gen (STArr n t) = 'A' : show (fromSNat n) ++ gen t
  gen (STScal st) = case st of
    STI32 -> "i4"
    STI64 -> "i8"
    STF32 -> "f4"
    STF64 -> "f8"
    STBool -> "b"
  gen (STAccum t) = 'C' : gen t

-- | The struct declaration(s) needed to represent the given type, keyed by
-- struct name.
--
-- Scalars need no struct and yield an empty map. All other types yield a
-- single declaration named by 'genStructName'; accumulators additionally
-- include the declarations of their payload type.
--
-- Field types are spelled with 'repTy' rather than 'genStructName': scalar
-- fields must use the plain C type (e.g. @int32_t@), because no struct is
-- ever declared for a scalar.
genStruct :: STy t -> Map String StructDecl
genStruct ty = case ty of
  STNil -> single ""
  STPair a b -> single (repTy a ++ " a; " ++ repTy b ++ " b;")
  -- Tagged union: 'tag' selects which member of the anonymous union is live.
  STEither a b ->
    single ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };")
  STMaybe t -> single ("uint8_t tag; " ++ repTy t ++ " a;")
  -- Shape array with one entry per dimension, plus a pointer to the elements.
  STArr n t ->
    single ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ repTy t ++ " *a;")
  STScal _ -> mempty
  STAccum t -> single (repTy t ++ " a;") <> genStruct t
  where
    name :: String
    name = genStructName ty

    -- A map holding just the declaration for 'ty' with the given contents.
    single :: String -> Map String StructDecl
    single contents = Map.singleton name (StructDecl name contents)

-- | Traverse an expression and collect the C struct declarations it needs.
--
-- The first component of the result maps struct names to their declarations;
-- the second component is currently just @()@ (presumably a placeholder for
-- generated code -- TODO confirm). Results of subexpressions are combined
-- with the tuple 'Monoid', i.e. the maps are unioned.
--
-- Declarations are only emitted at the nodes that did so originally (pairs,
-- either-injections, constant arrays and builds); every other node just
-- forwards the declarations of its subexpressions.
--
-- NOTE(review): the patterns for 'EWith' through 'EError' are kept exactly as
-- originally written (no leading extension field); confirm their arity
-- against the constructor definitions in "AST". For 'EPlus' and 'EOneHot' the
-- last two fields are assumed to be subexpressions -- verify.
compile :: Ex env t -> (Map String StructDecl, ())
compile = \case
  EVar _ _ _ -> mempty
  ELet _ rhs body -> compile rhs <> compile body
  EPair _ a b -> structs (genStruct (STPair (typeOf a) (typeOf b))) <> compile a <> compile b
  EFst _ e -> compile e
  ESnd _ e -> compile e
  ENil _ -> mempty
  EInl _ t e -> structs (genStruct (STEither (typeOf e) t)) <> compile e
  EInr _ t e -> structs (genStruct (STEither t (typeOf e))) <> compile e
  ECase _ e a b -> compile e <> compile a <> compile b
  ENothing _ _ -> mempty
  EJust _ e -> compile e
  EMaybe _ a b e -> compile a <> compile b <> compile e
  EConstArr _ n t _ -> structs (genStruct (STArr n (STScal t)))
  EBuild _ n a b -> structs (genStruct (STArr n (typeOf b))) <> compile a <> compile b
  EFold1Inner _ a b c -> compile a <> compile b <> compile c
  ESum1Inner _ e -> compile e
  EUnit _ e -> compile e
  EReplicate1Inner _ a b -> compile a <> compile b
  EMaximum1Inner _ e -> compile e
  EMinimum1Inner _ e -> compile e
  EConst _ _ _ -> mempty
  EIdx0 _ e -> compile e
  EIdx1 _ a b -> compile a <> compile b
  EIdx _ a b -> compile a <> compile b
  EShape _ e -> compile e
  EOp _ _ e -> compile e
  ECustom _ _ _ _ a b c e1 e2 -> compile a <> compile b <> compile c <> compile e1 <> compile e2
  EWith a b -> compile a <> compile b
  EAccum _ a b e -> compile a <> compile b <> compile e
  EZero _ -> mempty
  EPlus _ a b -> compile a <> compile b
  EOneHot _ _ a b -> compile a <> compile b
  EError _ _ -> mempty
  where
    -- Lift a set of struct declarations into the result type of 'compile'.
    structs :: Map String StructDecl -> (Map String StructDecl, ())
    structs decls = (decls, ())

-- | Compose a collection of endofunctions into one function. The function
-- that comes first in the fold order is applied last (outermost); an empty
-- collection gives 'id'.
compose :: Foldable t => t (a -> a) -> a -> a
compose fs = foldr (\f rest x -> f (rest x)) id fs