summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--codegen.hs93
-rw-r--r--test.nl3
2 files changed, 89 insertions, 7 deletions
diff --git a/codegen.hs b/codegen.hs
index 3461f64..5fd32d2 100644
--- a/codegen.hs
+++ b/codegen.hs
@@ -8,6 +8,7 @@ import qualified Data.Map.Strict as Map
import qualified LLVM.General.AST.Type as A
import qualified LLVM.General.AST.Global as A.G
import qualified LLVM.General.AST.Constant as A.C
+import qualified LLVM.General.AST.Float as A.F
-- import qualified LLVM.General.AST.Operand as A
-- import qualified LLVM.General.AST.Name as A
-- import qualified LLVM.General.AST.Instruction as A
@@ -106,7 +107,14 @@ setGlobalVar name label t = do
state $ \s -> ((), s {globalVariables = Map.insert name (t, label) $ globalVariables s})
lookupVar :: Name -> CGMonad (Type, LLName)
-lookupVar name = liftM (fromJust . Map.lookup name . variables) get
+lookupVar name | trace ("Looking up var " ++ name) False = undefined
+lookupVar name = do
+ obj <- get
+ let locfound = Map.lookup name $ variables obj
+ glofound = Map.lookup name $ globalVariables obj
+ if isJust locfound
+ then return $ fromJust locfound
+ else return $ fromJust glofound
lookupGlobalVar :: Name -> CGMonad (Type, LLName)
lookupGlobalVar name = liftM (fromJust . Map.lookup name . globalVariables) get
@@ -301,16 +309,49 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do
(TypePtr _) -> addInstr $ A.Sub False False e1op' e2op' []
(TypeName _) -> undefined
return $ A.LocalReference (toLLVMType t) (A.Name label)
+ Divide -> do
+ e1op' <- castOperand e1op t
+ e2op' <- castOperand e2op t
+ label <- case t of
+ (TypeInt _) -> addInstr $ A.SDiv False e1op' e2op' []
+ (TypeUInt _) -> addInstr $ A.UDiv False e1op' e2op' []
+ TypeFloat -> addInstr $ A.FDiv A.NoFastMathFlags e1op' e2op' []
+ TypeDouble -> addInstr $ A.FDiv A.NoFastMathFlags e1op' e2op' []
+ (TypePtr _) -> throwError $ "Modulo '%' operator not defined on pointers"
+ (TypeName _) -> undefined
+ return $ A.LocalReference (toLLVMType t) (A.Name label)
+ Modulo -> do
+ e1op' <- castOperand e1op t
+ e2op' <- castOperand e2op t
+ label <- case t of
+ (TypeInt _) -> addInstr $ A.SRem e1op' e2op' []
+ (TypeUInt _) -> addInstr $ A.URem e1op' e2op' []
+ TypeFloat -> addInstr $ A.FRem A.NoFastMathFlags e1op' e2op' []
+ TypeDouble -> addInstr $ A.FRem A.NoFastMathFlags e1op' e2op' []
+ (TypePtr _) -> throwError $ "Modulo '%' operator not defined on pointers"
+ (TypeName _) -> undefined
+ return $ A.LocalReference (toLLVMType t) (A.Name label)
Equal -> do
- sharedType <- commonType (fromJust (exTypeOf e1)) (fromJust (exTypeOf e2))
+ sharedType <- commonTypeM (fromJust (exTypeOf e1)) (fromJust (exTypeOf e2))
+ trace ("Shared type for Equal of " ++ pshow e1 ++ " and " ++ pshow e2 ++ " is: " ++ pshow sharedType)
+ $ return ()
e1op' <- castOperand e1op sharedType
e2op' <- castOperand e2op sharedType
label <- case sharedType of
(TypeInt _) -> addInstr $ A.ICmp A.EQ e1op' e2op' []
_ -> undefined
return $ A.LocalReference (A.IntegerType 1) (A.Name label)
+ Greater -> do
+ sharedType <- commonTypeM (fromJust (exTypeOf e1)) (fromJust (exTypeOf e2))
+ e1op' <- castOperand e1op sharedType
+ e2op' <- castOperand e2op sharedType
+ label <- case sharedType of
+ (TypeInt _) -> addInstr $ A.ICmp A.SGT e1op' e2op' []
+ (TypeUInt _) -> addInstr $ A.ICmp A.UGT e1op' e2op' []
+ _ -> undefined
+ return $ A.LocalReference (A.IntegerType 1) (A.Name label)
Less -> do
- sharedType <- commonType (fromJust (exTypeOf e1)) (fromJust (exTypeOf e2))
+ sharedType <- commonTypeM (fromJust (exTypeOf e1)) (fromJust (exTypeOf e2))
e1op' <- castOperand e1op sharedType
e2op' <- castOperand e2op sharedType
label <- case sharedType of
@@ -318,7 +359,25 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do
(TypeUInt _) -> addInstr $ A.ICmp A.ULT e1op' e2op' []
_ -> undefined
return $ A.LocalReference (A.IntegerType 1) (A.Name label)
+ BoolOr -> do
+ e1op' <- castToBool e1op
+ e2op' <- castToBool e2op
+ label <- addInstr $ A.Or e1op' e2op' []
+ return $ A.LocalReference (A.IntegerType 1) (A.Name label)
_ -> throwError $ "Binary operator " ++ pshow bo ++ " not implemented"
+genExpression (ExUnOp uo e1 (Just t)) = do
+ e1op <- genExprArgument e1
+ case uo of
+ Negate -> do
+ label <- case t of
+ (TypeInt s) -> addInstr $ A.Sub False False (A.ConstantOperand (A.C.Int (fromIntegral s) 0)) e1op []
+ (TypeUInt s) -> addInstr $ A.Sub False False (A.ConstantOperand (A.C.Int (fromIntegral s) 0)) e1op []
+ TypeFloat -> addInstr $ A.FSub A.NoFastMathFlags (A.ConstantOperand (A.C.Float (A.F.Single 0))) e1op []
+ TypeDouble -> addInstr $ A.FSub A.NoFastMathFlags (A.ConstantOperand (A.C.Float (A.F.Double 0))) e1op []
+ (TypePtr _) -> throwError $ "Negate '-' operator not defined on a pointer"
+ (TypeName _) -> undefined
+ return $ A.LocalReference (toLLVMType t) (A.Name label)
+ _ -> throwError $ "Unary operator " ++ pshow uo ++ " not implemented"
genExpression ex = throwError $ "Expression '" ++ pshow ex ++ "' not implemented"
genExprArgument :: Expression -> CGMonad A.Operand
@@ -335,6 +394,9 @@ literalToOperand (LitVar n) t = do
literalToOperand lit _ = throwError $ "Literal '" ++ pshow lit ++ "' not implemented"
castOperand :: A.Operand -> Type -> CGMonad A.Operand
+castOperand orig@(A.LocalReference (A.IntegerType 1) _) t2@(TypeInt _) = do
+ label <- addInstr $ A.ZExt orig (toLLVMType t2) []
+ return $ A.LocalReference (toLLVMType t2) (A.Name label)
castOperand orig@(A.LocalReference (A.IntegerType s1) _) t2@(TypeInt s2)
| fromIntegral s1 == s2 = return orig
| fromIntegral s1 < s2 = do
@@ -365,9 +427,28 @@ castToBool (A.ConstantOperand (A.C.Int _ val)) =
return $ A.ConstantOperand (A.C.Int 1 (if val == 0 then 1 else 0))
castToBool _ = undefined
-commonType :: Type -> Type -> CGMonad Type
-commonType (TypeInt s1) (TypeInt s2) = return $ TypeInt (max s1 s2)
-commonType _ _ = undefined
+
+commonType :: Type -> Type -> Maybe Type
+commonType (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypePtr t1
+commonType (TypePtr _) _ = Nothing
+commonType _ (TypePtr _) = Nothing
+
+commonType (TypeInt s1) (TypeInt s2) = Just $ TypeInt (max s1 s2)
+
+commonType (TypeUInt s1) (TypeUInt s2) = Just $ TypeUInt (max s1 s2)
+
+commonType TypeFloat (TypeInt _) = Just TypeFloat
+commonType (TypeInt _) TypeFloat = Just TypeFloat
+commonType TypeDouble (TypeInt _) = Just TypeDouble
+commonType (TypeInt _) TypeDouble = Just TypeDouble
+commonType TypeFloat TypeDouble = Just TypeDouble
+commonType TypeDouble TypeFloat = Just TypeDouble
+
+commonType _ _ = Nothing
+
+commonTypeM :: Type -> Type -> CGMonad Type
+commonTypeM t1 t2 = maybe err return $ commonType t1 t2
+ where err = throwError $ "Cannot implicitly find common type of '" ++ pshow t1 ++ "' and '" ++ pshow t2 ++ "'"
cleanupTrampolines :: CGMonad ()
diff --git a/test.nl b/test.nl
index 99786fc..5e23bf3 100644
--- a/test.nl
+++ b/test.nl
@@ -1,9 +1,10 @@
type int = i32;
type char = i8;
-int glob = 10;
+int glob;
int main(int argc, ptr(char) argv) {
+ glob = 10;
int kaas = glob + 2;
glob = 2 > 1 || 1 == 1 % 10;
while (glob < 20) {