diff options
Diffstat (limited to 'src/Compile.hs')
-rw-r--r-- | src/Compile.hs | 76 |
1 files changed, 38 insertions, 38 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 6ba3a39..9ed5a27 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -77,7 +77,7 @@ compile = \env expr -> do let (source, offsets) = compileToString codeID env expr when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" - lib <- buildKernel source ["kernel"] + lib <- buildKernel source "kernel" let result_type = typeOf expr result_size = sizeofSTy result_type @@ -86,7 +86,7 @@ compile = \env expr -> do allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) serialiseArguments args ptr $ do - callKernelFun "kernel" lib ptr + callKernelFun lib ptr ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) when (ok /= 1) $ ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) @@ -221,6 +221,7 @@ genStructName = \t -> "ty_" ++ gen t where gen STNil = "n" gen (STPair a b) = 'P' : gen a ++ gen b gen (STEither a b) = 'E' : gen a ++ gen b + gen (STLEither a b) = 'L' : gen a ++ gen b gen (STMaybe t) = 'M' : gen t gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t gen (STScal st) = case st of @@ -230,7 +231,6 @@ genStructName = \t -> "ty_" ++ gen t where STF64 -> "d" STBool -> "b" gen (STAccum t) = 'C' : gen (fromSMTy t) - gen (STLEither a b) = 'L' : gen a ++ gen b -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the @@ -246,6 +246,8 @@ genStruct name topty = case topty of [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] STEither a b -> -- 0 -> l, 1 -> r [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r + [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] STMaybe t -> -- 0 -> nothing, 1 -> just [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] STArr n t -> @@ -259,8 +261,6 @@ genStruct name topty = case topty of STAccum t -> [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" ,StructDecl name (name ++ "_buf *buf;") com] - STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r - [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] where com = ppSTy 0 topty @@ -282,11 +282,11 @@ genStructs ty = do STNil -> pure () STPair a b -> genStructs a >> genStructs b STEither a b -> genStructs a >> genStructs b + STLEither a b -> genStructs a >> genStructs b STMaybe t -> genStructs t STArr _ t -> genStructs t STScal _ -> pure () STAccum t -> genStructs (fromSMTy t) - STLEither a b -> genStructs a >> genStructs b tell (BList (genStruct name ty)) @@ -463,6 +463,15 @@ serialise topty topval ptr off k = (STEither _ b, Right y) -> do pokeByteOff ptr off (1 :: Word8) serialise b y ptr (off + alignmentSTy topty) k + (STLEither _ _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STLEither a _, Just (Left x)) -> do + pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STLEither _ b, Just (Right y)) -> do + pokeByteOff ptr off (2 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k (STMaybe _, Nothing) -> do pokeByteOff ptr off (0 :: Word8) k @@ -493,15 +502,6 @@ serialise topty topval ptr off k = STF64 -> pokeByteOff ptr off (x :: Double) >> k STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k (STAccum{}, _) -> error "Cannot serialise accumulators" - (STLEither _ _, Nothing) -> do - pokeByteOff ptr off (0 :: Word8) - k - (STLEither a _, Just (Left x)) -> do - pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) - serialise a x ptr (off + alignmentSTy topty) k - (STLEither _ b, Just (Right y)) -> do - pokeByteOff ptr off (2 :: Word8) - serialise b y ptr (off + alignmentSTy topty) k -- | Assumes that this is called at the correct alignment. deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) @@ -518,6 +518,13 @@ deserialise topty ptr off = if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) then Left <$> deserialise a ptr (off + alignmentSTy topty) else Right <$> deserialise b ptr (off + alignmentSTy topty) + STLEither a b -> do + tag <- peekByteOff @Word8 ptr off + case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) + 0 -> return Nothing + 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) + 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) + _ -> error "Invalid tag value" STMaybe t -> do tag <- peekByteOff @Word8 ptr off if tag == 0 @@ -541,13 +548,6 @@ deserialise topty ptr off = STF64 -> peekByteOff @Double ptr off STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off STAccum{} -> error "Cannot serialise accumulators" - STLEither a b -> do - tag <- peekByteOff @Word8 ptr off - case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) - 0 -> return Nothing - 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) - 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) - _ -> error "Invalid tag value" align :: Int -> Int -> Int align a off = (off + a - 1) `div` a * a @@ -569,6 +569,10 @@ metricsSTy (STEither a b) = let (a1, s1) = metricsSTy a (a2, s2) = metricsSTy b in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned +metricsSTy (STLEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned metricsSTy (STMaybe t) = let (a, s) = metricsSTy t in (a, a + s) -- the union after the tag byte is aligned @@ -580,10 +584,6 @@ metricsSTy (STScal sty) = case sty of STF64 -> (8, 8) STBool -> (1, 1) -- compiled to uint8_t metricsSTy (STAccum t) = metricsSTy (fromSMTy t) -metricsSTy (STLEither a b) = - let (a1, s1) = metricsSTy a - (a2, s2) = metricsSTy b - in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () pokeShape ptr off = go . fromSNat @@ -1071,8 +1071,8 @@ compile' env = \case incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend emit $ SAsg v (CELit addend) -- sparse types - (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") + (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") -- dense types (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do f (v++".a") (i++".a") @@ -1303,13 +1303,13 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) (smartATProj "b" (makeArrayTree b)) makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a)) (smartATProj "r" (makeArrayTree b)) +makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop + (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t)) makeArrayTree (STArr n t) = ATArray (Some n) (Some t) makeArrayTree (STScal _) = ATNoop makeArrayTree (STAccum _) = ATNoop -makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop - (smartATProj "l" (makeArrayTree a)) - (smartATProj "r" (makeArrayTree b)) incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM () incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = @@ -1657,6 +1657,14 @@ zeroRefcountCheck toptyp opname topvar = go (STEither a b) path = do (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 + go (STLEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ + SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) + s1 + (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) + s2 + mempty)) go (STMaybe a) path = do ss <- go a (path++".j") return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty @@ -1673,14 +1681,6 @@ zeroRefcountCheck toptyp opname topvar = return (BList [s1, s2, s3]) go STScal{} _ = empty go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" - go (STLEither a b) path = do - (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) - return $ pure $ - SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) - s1 - (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) - s2 - mempty)) combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) combine (MaybeT a) (MaybeT b) = MaybeT $ do |