From b9b2ccd5155f8ce14cc9b4b04fffe56b988a3bdd Mon Sep 17 00:00:00 2001 From: tomsmeding Date: Wed, 1 Feb 2017 23:12:40 +0100 Subject: Pointer arithmetic! --- check.hs | 4 ++- codegen.hs | 104 ++++++++++++++++++++++++++++++++++++++++++----------- nl/string_index.nl | 12 +++++++ 3 files changed, 99 insertions(+), 21 deletions(-) create mode 100644 nl/string_index.nl diff --git a/check.hs b/check.hs index 22d8196..65c1470 100644 --- a/check.hs +++ b/check.hs @@ -214,6 +214,8 @@ complogBO = compareBO ++ logicBO resultTypeBO :: BinaryOperator -> Type -> Type -> Maybe Type resultTypeBO Minus (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypeUInt 1 resultTypeBO bo (TypePtr t1) (TypePtr t2) | t1 == t2 && bo `elem` complogBO = Just $ TypeUInt 1 +resultTypeBO bo t@(TypePtr _) (TypeInt _) | bo `elem` [Plus, Minus] = Just t +resultTypeBO bo (TypeInt _) t@(TypePtr _) | bo `elem` [Plus, Minus] = Just t resultTypeBO _ (TypePtr _) _ = Nothing resultTypeBO _ _ (TypePtr _) = Nothing @@ -243,7 +245,7 @@ resultTypeUO uo t@(TypeInt _) | uo `elem` [Negate, Invert] = Just t resultTypeUO uo t@(TypeUInt _) | uo `elem` [Negate, Invert] = Just t resultTypeUO Negate TypeFloat = Just TypeFloat resultTypeUO Negate TypeDouble = Just TypeDouble -resultTypeUO Dereference t@(TypePtr _) = Just t +resultTypeUO Dereference (TypePtr t) = Just t resultTypeUO _ _ = Nothing smallestFloatType :: Double -> Type diff --git a/codegen.hs b/codegen.hs index e56caa4..507feb6 100644 --- a/codegen.hs +++ b/codegen.hs @@ -392,19 +392,52 @@ genExpression :: Expression -> CGMonad A.Operand genExpression (ExLit lit (Just t)) = literalToOperand lit t genExpression (ExBinOp bo e1 e2 (Just t)) = do case bo of - Plus -> do - e1op <- genExprArgument e1 >>= flip castOperand t - e2op <- genExprArgument e2 >>= flip castOperand t - label <- case t of - (TypeInt _) -> addInstr $ A.Add False False e1op e2op [] - (TypeUInt _) -> addInstr $ A.Add False False e1op e2op [] - TypeFloat -> addInstr $ A.FAdd A.NoFastMathFlags e1op e2op [] - TypeDouble -> addInstr $ A.FAdd A.NoFastMathFlags e1op e2op [] - (TypePtr _) -> throwError $ "Plus '+' operator not defined on pointers" - (TypeFunc _ _) -> throwError $ "Plus '+' operator not defined on function pointers" - (TypeName _) -> undefined - TypeVoid -> undefined - return $ A.LocalReference (toLLVMType t) (A.Name label) + Plus -> case (fromJust (exTypeOf e1), fromJust (exTypeOf e2)) of + (ptrtype@(TypePtr _), TypeInt _) -> do + ptrop <- genExprArgument e1 + intop <- genExprArgument e2 >>= flip castOperand (TypeInt 64) + ptrlabel <- addInstr $ A.PtrToInt ptrop (toLLVMType (TypeUInt 64)) [] + add <- addInstr $ A.Add False False (A.LocalReference (toLLVMType ptrtype) (A.Name ptrlabel)) intop [] + res <- addInstr $ A.IntToPtr (A.LocalReference (toLLVMType (TypeUInt 64)) (A.Name add)) + (toLLVMType t) [] + return $ A.LocalReference (toLLVMType ptrtype) (A.Name res) + (ptrtype@(TypePtr _), TypeUInt _) -> do + ptrop <- genExprArgument e1 + intop <- genExprArgument e2 >>= flip castOperand (TypeUInt 64) + ptrlabel <- addInstr $ A.PtrToInt ptrop (toLLVMType (TypeUInt 64)) [] + add <- addInstr $ A.Add False False (A.LocalReference (toLLVMType ptrtype) (A.Name ptrlabel)) intop [] + res <- addInstr $ A.IntToPtr (A.LocalReference (toLLVMType (TypeUInt 64)) (A.Name add)) + (toLLVMType t) [] + return $ A.LocalReference (toLLVMType ptrtype) (A.Name res) + (TypeInt _, ptrtype@(TypePtr _)) -> do + ptrop <- genExprArgument e2 + intop <- genExprArgument e1 >>= flip castOperand (TypeInt 64) + ptrlabel <- addInstr $ A.PtrToInt ptrop (toLLVMType (TypeUInt 64)) [] + add <- addInstr $ A.Add False False (A.LocalReference (toLLVMType ptrtype) (A.Name ptrlabel)) intop [] + res <- addInstr $ A.IntToPtr (A.LocalReference (toLLVMType (TypeUInt 64)) (A.Name add)) + (toLLVMType t) [] + return $ A.LocalReference (toLLVMType ptrtype) (A.Name res) + (TypeUInt _, ptrtype@(TypePtr _)) -> do + ptrop <- genExprArgument e2 + intop <- genExprArgument e1 >>= flip castOperand (TypeUInt 64) + ptrlabel <- addInstr $ A.PtrToInt ptrop (toLLVMType (TypeUInt 64)) [] + add <- addInstr $ A.Add False False (A.LocalReference (toLLVMType ptrtype) (A.Name ptrlabel)) intop [] + res <- addInstr $ A.IntToPtr (A.LocalReference (toLLVMType (TypeUInt 64)) (A.Name add)) + (toLLVMType t) [] + return $ A.LocalReference (toLLVMType t) (A.Name res) + _ -> do + e1op <- genExprArgument e1 >>= flip castOperand t + e2op <- genExprArgument e2 >>= flip castOperand t + label <- case t of + (TypeInt _) -> addInstr $ A.Add False False e1op e2op [] + (TypeUInt _) -> addInstr $ A.Add False False e1op e2op [] + TypeFloat -> addInstr $ A.FAdd A.NoFastMathFlags e1op e2op [] + TypeDouble -> addInstr $ A.FAdd A.NoFastMathFlags e1op e2op [] + (TypePtr _) -> throwError $ "Plus '+' operator not defined on pointers" + (TypeFunc _ _) -> throwError $ "Plus '+' operator not defined on function pointers" + (TypeName _) -> undefined + TypeVoid -> undefined + return $ A.LocalReference (toLLVMType t) (A.Name label) Minus -> do e1op <- genExprArgument e1 >>= flip castOperand t e2op <- genExprArgument e2 >>= flip castOperand t @@ -471,6 +504,20 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do (TypeName _) -> undefined TypeVoid -> undefined return $ A.LocalReference (A.IntegerType 1) (A.Name label) + Unequal -> do + sharedType <- commonTypeM (fromJust (exTypeOf e1)) (fromJust (exTypeOf e2)) + e1op <- genExprArgument e1 >>= flip castOperand sharedType + e2op <- genExprArgument e2 >>= flip castOperand sharedType + label <- case sharedType of + (TypeInt _) -> addInstr $ A.ICmp A.IP.NE e1op e2op [] + (TypeUInt _) -> addInstr $ A.ICmp A.IP.NE e1op e2op [] + TypeFloat -> addInstr $ A.FCmp A.FPP.ONE e1op e2op [] + TypeDouble -> addInstr $ A.FCmp A.FPP.ONE e1op e2op [] + (TypePtr _) -> addInstr $ A.ICmp A.IP.NE e1op e2op [] + (TypeFunc _ _) -> addInstr $ A.ICmp A.IP.NE e1op e2op [] + (TypeName _) -> undefined + TypeVoid -> undefined + return $ A.LocalReference (A.IntegerType 1) (A.Name label) Greater -> do sharedType <- commonTypeM (fromJust (exTypeOf e1)) (fromJust (exTypeOf e2)) e1op <- genExprArgument e1 >>= flip castOperand sharedType @@ -548,11 +595,27 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do reslabel <- addInstr $ A.Phi A.i1 [(A.ConstantOperand (A.C.Int 1 0), A.Name firstbb), (A.LocalReference A.i1 (A.Name label2), A.Name bb2)] [] return $ A.LocalReference A.i1 (A.Name reslabel) - -- BoolOr -> do - -- e1op' <- castToBool e1op - -- e2op' <- castToBool e2op - -- label <- addInstr $ A.Or e1op' e2op' [] - -- return $ A.LocalReference (A.IntegerType 1) (A.Name label) + BoolOr -> do + firstbb <- liftM (fromJust . currentBlock) get + (A.LocalReference (A.IntegerType 1) (A.Name label1)) <- genExprArgument e1 >>= castToBool + (A.Do origterm) <- getTerminator + + bb2 <- newBlock + (A.LocalReference (A.IntegerType 1) (A.Name label2)) <- genExprArgument e2 >>= castToBool + + bb3 <- newBlock + + changeBlock firstbb + setTerminator $ A.CondBr (A.LocalReference A.i1 (A.Name label1)) (A.Name bb3) (A.Name bb2) [] + + changeBlock bb2 + setTerminator $ A.Br (A.Name bb3) [] + + changeBlock bb3 + setTerminator origterm + reslabel <- addInstr $ A.Phi A.i1 [(A.ConstantOperand (A.C.Int 1 1), A.Name firstbb), + (A.LocalReference A.i1 (A.Name label2), A.Name bb2)] [] + return $ A.LocalReference A.i1 (A.Name reslabel) _ -> throwError $ "Binary operator " ++ pshow bo ++ " not implemented" genExpression (ExUnOp uo e1 (Just t)) = do e1op <- genExprArgument e1 @@ -580,8 +643,9 @@ genExpression (ExUnOp uo e1 (Just t)) = do TypeVoid -> undefined return $ A.LocalReference (toLLVMType t) (A.Name label) Dereference -> do - label <- case t of - (TypePtr _) -> addInstr $ A.Load False e1op Nothing 0 [] + let (A.LocalReference optype _) = e1op + label <- case optype of + (A.PointerType _ _) -> addInstr $ A.Load False e1op Nothing 0 [] _ -> throwError $ "Dereference '*' operator only defined on pointers" return $ A.LocalReference (toLLVMType t) (A.Name label) _ -> throwError $ "Unary operator " ++ pshow uo ++ " not implemented" diff --git a/nl/string_index.nl b/nl/string_index.nl new file mode 100644 index 0000000..27df14b --- /dev/null +++ b/nl/string_index.nl @@ -0,0 +1,12 @@ +extern func void(i32) putchar; + +i32 main() { + ptr(i8) s = "kaas"; + i32 i = 0; + while (*(s+i) != '\x00') { + putchar(*(s + i)); + i = i + 1; + } + putchar('\n'); + return 0; +} -- cgit v1.2.3-70-g09d2