{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module Compile where

import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)

import AST
import Data


-- In shape and index arrays, the innermost dimension is on the right (last index).

-- | A C struct typedef that has to be emitted: its name and the text that
-- goes between the braces.
data StructDecl = StructDecl
    String  -- ^ name
    String  -- ^ contents

-- | Render a 'StructDecl' as a single-line C @typedef struct@ declaration,
-- terminated by a newline.
printStructDecl :: StructDecl -> ShowS
printStructDecl (StructDecl name contents) =
  showString "typedef struct { "
    . showString contents
    . showString " }"
    . showString name
    . showString ";\n"

-- | The C type used to represent a value of the given type: the primitive C
-- type for scalars, and the generated struct name (see 'genStructName') for
-- everything else.
repTy :: STy t -> String
repTy (STScal st) = case st of
  STI32 -> "int32_t"
  STI64 -> "int64_t"
  STF32 -> "float"
  STF64 -> "double"
  STBool -> "bool"
repTy t = genStructName t

-- | The name of the generated struct for a type: @ty_@ followed by a prefix
-- encoding of the type's structure.
--
-- Only letters and digits are produced after the @ty_@ prefix, so the result
-- can be used as a C identifier. The array rank is written as a bare decimal
-- number; this stays unambiguous because every element encoding starts with a
-- letter.
genStructName :: STy t -> String
genStructName = \t -> "ty_" ++ gen t
  where
    gen :: STy t -> String
    gen STNil = "n"
    gen (STPair a b) = 'p' : gen a ++ gen b
    gen (STEither a b) = 'e' : gen a ++ gen b
    gen (STMaybe t) = 'm' : gen t
    gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
    gen (STScal st) = case st of
      STI32 -> "i4"
      STI64 -> "i8"
      STF32 -> "f4"
      STF64 -> "f8"
      STBool -> "b"
    gen (STAccum t) = 'C' : gen t

-- | All struct declarations needed to represent a value of the given type,
-- keyed by struct name. This includes the declarations for the component
-- types. Scalars need no struct, so they contribute nothing.
--
-- Field types are written with 'repTy', so that scalar fields get their
-- primitive C type instead of a struct name that is never declared.
--
-- NOTE(review): the map is keyed by name, so enumerating it does not yield
-- the declarations in dependency order; whoever prints them has to make sure
-- a struct is emitted after the structs it contains.
genStruct :: STy t -> Map String StructDecl
genStruct = \case
  STNil -> decl STNil ""
  ty@(STPair a b) ->
    decl ty (repTy a ++ " a; " ++ repTy b ++ " b;")
      <> genStruct a
      <> genStruct b
  ty@(STEither a b) ->
    decl ty ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };")
      <> genStruct a
      <> genStruct b
  ty@(STMaybe t) ->
    decl ty ("uint8_t tag; " ++ repTy t ++ " a;")
      <> genStruct t
  ty@(STArr n t) ->
    decl ty ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ repTy t ++ " *a;")
      <> genStruct t
  STScal _ -> mempty
  ty@(STAccum t) ->
    decl ty (repTy t ++ " a;")
      <> genStruct t
  where
    -- A single-entry map holding the declaration for @ty@ with the given body.
    decl :: STy a -> String -> Map String StructDecl
    decl ty contents =
      let name = genStructName ty
      in Map.singleton name (StructDecl name contents)

-- | Collect the struct declarations required by an expression.
--
-- The result is the union of the declarations requested by the node itself
-- (pairs, sums, constant arrays and @build@ results) and those of all its
-- sub-expressions. The second tuple component carries no information yet.
compile :: Ex env t -> (Map String StructDecl, ())
compile = \case
  EVar _ _ _ -> mempty
  ELet _ rhs body -> compile rhs <> compile body
  EPair _ a b -> structsFor (STPair (typeOf a) (typeOf b)) <> compile a <> compile b
  EFst _ e -> compile e
  ESnd _ e -> compile e
  ENil _ -> mempty
  EInl _ t e -> structsFor (STEither (typeOf e) t) <> compile e
  EInr _ t e -> structsFor (STEither t (typeOf e)) <> compile e
  ECase _ e a b -> compile e <> compile a <> compile b
  ENothing _ _ -> mempty
  EJust _ e -> compile e
  EMaybe _ a b e -> compile a <> compile b <> compile e
  EConstArr _ n t _ -> structsFor (STArr n (STScal t))
  EBuild _ n a b -> structsFor (STArr n (typeOf b)) <> compile a <> compile b
  EFold1Inner _ a b c -> compile a <> compile b <> compile c
  ESum1Inner _ e -> compile e
  EUnit _ e -> compile e
  EReplicate1Inner _ a b -> compile a <> compile b
  EMaximum1Inner _ e -> compile e
  EMinimum1Inner _ e -> compile e
  EConst _ _ _ -> mempty
  EIdx0 _ e -> compile e
  EIdx1 _ a b -> compile a <> compile b
  EIdx _ a b -> compile a <> compile b
  EShape _ e -> compile e
  EOp _ _ e -> compile e
  ECustom _ _ _ _ a b c e1 e2 ->
    mconcat [compile a, compile b, compile c, compile e1, compile e2]
  EWith a b -> compile a <> compile b
  EAccum _ a b e -> compile a <> compile b <> compile e
  EZero _ -> mempty
  EPlus _ a b -> compile a <> compile b
  EOneHot _ _ a b -> compile a <> compile b
  EError _ _ -> mempty
  where
    -- Lift the declarations for a type into the result type of 'compile'.
    structsFor :: STy a -> (Map String StructDecl, ())
    structsFor ty = (genStruct ty, ())

-- | Compose a collection of endofunctions into one; in a list, the first
-- function is the outermost one (applied last). An empty collection yields
-- 'id'.
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id