From 230e9775f5b61e21aa085825fbbd0232e9a360ef Mon Sep 17 00:00:00 2001
From: tomsmeding <tom.smeding@gmail.com>
Date: Tue, 14 Mar 2017 13:39:51 +0100
Subject: Room listing and creation

---
 command.c   | 123 +++++++++++++++++++++++++++++++++++++++++++++++++++++-------
 command.h   |   3 +-
 conn_data.c |   4 ++
 conn_data.h |   2 +
 db.c        | 107 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 db.h        |  10 +++++
 main.c      |   7 +++-
 memory.c    |   7 ++++
 memory.h    |   3 ++
 schema.sql  |   7 +---
 10 files changed, 249 insertions(+), 24 deletions(-)

diff --git a/command.c b/command.c
index deaeabe..9d5e2e8 100644
--- a/command.c
+++ b/command.c
@@ -40,20 +40,108 @@ static bool send_error(int fd,const char *tag,const char *msg){
 	return closed;
 }
 
+static bool send_name(int fd,const char *tag,const char *name){
+	char *buf=NULL;
+	i64 len=asprintf(&buf,"%s name %s\n",tag,name);
+	assert(buf);
+	bool closed=send_raw_text(fd,buf,len);
+	free(buf);
+	return closed;
+}
 
-static bool cmd_register(int fd,const char *tag,const char **args){
+static bool send_list(int fd,const char *tag,i64 count,const char **list){
+	char *buf=NULL;
+	i64 len=asprintf(&buf,"%s list %lld",tag,count);
+	assert(buf);
+	bool closed=send_raw_text(fd,buf,len);
+	free(buf);
+	if(closed)return true;
+
+	if(count>0){
+		i64 bufsz=64;
+		buf=malloc(bufsz,char);
+
+		for(i64 i=0;i<count;i++){
+			i64 len=strlen(list[i]);
+			if(len>=bufsz){
+				bufsz=len+512;
+				buf=realloc(buf,bufsz,char);
+			}
+			memcpy(buf+1,list[i],len);
+			buf[0]=' ';
+			if(send_raw_text(fd,buf,len+1)){
+				free(buf);
+				return true;
+			}
+		}
+
+		free(buf);
+	}
+
+	return send_raw_text(fd,"\n",1);
+}
+
+
+static bool cmd_register(struct conn_data *data,const char *tag,const char **args){
 	i64 userid=db_find_user(args[0]);
 	if(userid!=-1){
-		send_error(fd,tag,"Username already exists");
+		send_error(data->fd,tag,"Username already exists");
 		return false;
 	}
 	db_create_user(args[0],args[1]);
-	return send_ok(fd,tag);
+	return send_ok(data->fd,tag);
 }
 
-static bool cmd_login(int fd,const char *tag,const char **args){
-	(void)fd; (void)tag; (void)args;
-	return true;
+static bool cmd_login(struct conn_data *data,const char *tag,const char **args){
+	i64 userid=db_find_user(args[0]);
+	if(userid==-1){
+		send_error(data->fd,tag,"Username does not exist");
+		return false;
+	}
+	char *pass=db_get_pass(userid);
+	bool success=strcmp(args[1],pass)==0;
+	free(pass);
+	if(success){
+		data->userid=userid;
+		send_ok(data->fd,tag);
+	} else {
+		data->userid=-1;
+		send_error(data->fd,tag,"Incorrect password");
+	}
+	return false;
+}
+
+static bool cmd_list_rooms(struct conn_data *data,const char *tag,const char **args){
+	(void)args;
+	if(data->userid==-1){
+		send_error(data->fd,tag,"Not logged in");
+		return false;
+	}
+	struct db_room_list rl=db_list_rooms(data->userid);
+	if(rl.count<=0){
+		db_nullify_room_list(rl);
+		return send_list(data->fd,tag,0,NULL);
+	}
+	const char *names[rl.count];
+	for(i64 i=0;i<rl.count;i++){
+		names[i]=rl.list[i].name;
+	}
+	bool closed=send_list(data->fd,tag,rl.count,names);
+	db_nullify_room_list(rl);
+	return closed;
+}
+
+static bool cmd_create_room(struct conn_data *data,const char *tag,const char **args){
+	(void)args;
+	if(data->userid==-1){
+		send_error(data->fd,tag,"Not logged in");
+		return false;
+	}
+	struct db_name_id room=db_create_room();
+	db_add_member(room.id,data->userid);
+	bool closed=send_name(data->fd,tag,room.name);
+	db_nullify_name_id(room);
+	return closed;
 }
 
 
@@ -61,20 +149,22 @@ struct cmd_info{
 	const char *cmdname;
 	int nargs;
 	bool longlast;  // whether the last argument should span the rest of the input line
-	bool (*handler)(int fd,const char *tag,const char **args);
+	bool (*handler)(struct conn_data *data,const char *tag,const char **args);
 };
 
 static const struct cmd_info commands[]={
-	{"register",2,false,cmd_register},
-	{"login",2,false,cmd_login},
+	{"register",2,true,cmd_register},
+	{"login",2,true,cmd_login},
+	{"list_rooms",0,false,cmd_list_rooms},
+	{"create_room",0,false,cmd_create_room},
 };
 #define NCOMMANDS (sizeof(commands)/sizeof(commands[0]))
 
 
-bool handle_input_line(int fd,char *line,size_t linelen){
+bool handle_input_line(struct conn_data *data,char *line,size_t linelen){
 	char *sepp=memchr(line,' ',linelen);
 	if(sepp==NULL){
-		debug("No space in input line from connection %d",fd);
+		debug("No space in input line from connection %d",data->fd);
 		return true;
 	}
 	char *tag=line;
@@ -93,13 +183,18 @@ bool handle_input_line(int fd,char *line,size_t linelen){
 		}
 	}
 
+	if(cmdi==NCOMMANDS){
+		debug("Unknown command %s on connection %d",line,data->fd);
+		return true;
+	}
+
 	int nargs=commands[cmdi].nargs;
 	char *args[nargs];
 	size_t cursor=cmdlen+1;
 
 	for(int i=0;i<nargs;i++){
 		if(cursor>=linelen){
-			debug("Connection %d sent too few parameters to command %s",fd,commands[cmdi].cmdname);
+			debug("Connection %d sent too few parameters to command %s",data->fd,commands[cmdi].cmdname);
 			return true;
 		}
 		if(i==nargs-1&&commands[cmdi].longlast){
@@ -113,9 +208,9 @@ bool handle_input_line(int fd,char *line,size_t linelen){
 		cursor=sepp-line+1;
 	}
 	if(sepp-line<(i64)linelen){
-		debug("Connection %d sent too many parameters to command %s",fd,commands[cmdi].cmdname);
+		debug("Connection %d sent too many parameters to command %s",data->fd,commands[cmdi].cmdname);
 		return true;
 	}
 
-	return commands[cmdi].handler(fd,tag,(const char**)args);
+	return commands[cmdi].handler(data,tag,(const char**)args);
 }
diff --git a/command.h b/command.h
index 9bfac7f..3b01f42 100644
--- a/command.h
+++ b/command.h
@@ -1,8 +1,9 @@
 #pragma once
 
 #include "global.h"
+#include "conn_data.h"
 
 
 // Returns true if socket should be closed.
 // Modifies some bytes in `line`, AS WELL AS line[linelen]!
-bool handle_input_line(int fd,char *line,size_t linelen);
+bool handle_input_line(struct conn_data *data,char *line,size_t linelen);
diff --git a/conn_data.c b/conn_data.c
index 1b13cec..1b237b4 100644
--- a/conn_data.c
+++ b/conn_data.c
@@ -6,6 +6,8 @@ void conn_data_init(struct conn_data *data,int fd){
 	data->bufsz=512;
 	data->buflen=0;
 	data->buffer=malloc(data->bufsz,char);
+
+	data->userid=-1;
 }
 
 void conn_data_nullify(struct conn_data *data){
@@ -13,4 +15,6 @@ void conn_data_nullify(struct conn_data *data){
 	data->buffer=NULL;
 	data->bufsz=0;
 	data->buflen=0;
+
+	data->userid=-1;
 }
diff --git a/conn_data.h b/conn_data.h
index 871af07..dae100f 100644
--- a/conn_data.h
+++ b/conn_data.h
@@ -7,6 +7,8 @@ struct conn_data{
 	int fd;
 	i64 bufsz,buflen;
 	char *buffer;
+
+	i64 userid;  // -1 if not logged in
 };
 
 void conn_data_init(struct conn_data *data,int fd);  // Initialises buffers
diff --git a/db.c b/db.c
index 3b7a61b..1997c5a 100644
--- a/db.c
+++ b/db.c
@@ -1,4 +1,6 @@
+#include <stdlib.h>
 #include <string.h>
+#include <assert.h>
 #include <sqlite3.h>
 #include "db.h"
 #include "schema.sql.h"
@@ -29,6 +31,82 @@ void db_close(void){
 }
 
 
+static char* gen_room_name(void){
+	const int name_len=8;
+	const char *alphabet="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
+	const int alpha_len=strlen(alphabet);
+
+	char *name=malloc(name_len+1,char);
+	for(int i=0;i<name_len;i++){
+		name[i]=alphabet[random()%alpha_len];
+	}
+	name[name_len]='\0';
+	debug("Generated name: %s",name);
+	return name;
+}
+
+struct db_name_id db_create_room(void){
+	char *name=gen_room_name();
+
+	sqlite3_stmt *stmt;
+	SQLITE(prepare_v2,database,"insert into Rooms (name) values (?)",-1,&stmt,NULL);
+	SQLITE(bind_text,stmt,1,name,-1,SQLITE_STATIC);
+	bool success=sqlite3_step(stmt)==SQLITE_DONE;
+	SQLITE(finalize,stmt);
+
+	i64 userid=sqlite3_last_insert_rowid(database);
+
+	if(!success){
+		free(name);
+		return (struct db_name_id){NULL,-1};
+	}
+
+	return (struct db_name_id){name,userid};
+}
+
+bool db_add_member(i64 roomid,i64 userid){
+	assert(roomid!=-1&&userid!=-1);
+	sqlite3_stmt *stmt;
+	SQLITE(prepare_v2,database,"insert into Members (room, user) values (?, ?)",-1,&stmt,NULL);
+	SQLITE(bind_int64,stmt,1,roomid);
+	SQLITE(bind_int64,stmt,2,userid);
+	bool success=sqlite3_step(stmt)==SQLITE_DONE;
+	SQLITE(finalize,stmt);
+	return success;
+}
+
+struct db_room_list db_list_rooms(i64 userid){
+	sqlite3_stmt *stmt;
+	SQLITE(prepare_v2,database,
+			"select M.room, R.name "
+			"from Members as M, Rooms as R "
+			"where M.user = ? and M.room = R.id"
+		,-1,&stmt,NULL);
+	SQLITE(bind_int64,stmt,1,userid);
+
+	struct db_room_list rl;
+	i64 cap=4;
+	rl.count=0;
+	rl.list=malloc(cap,struct db_name_id);
+
+	int ret;
+	while((ret=sqlite3_step(stmt))==SQLITE_ROW){
+		if(rl.count==cap){
+			cap*=2;
+			rl.list=realloc(rl.list,cap,struct db_name_id);
+		}
+		rl.list[rl.count].id=sqlite3_column_int64(stmt,0);
+		rl.list[rl.count].name=strdup((const char*)sqlite3_column_text(stmt,1));
+		rl.count++;
+	}
+
+	if(ret!=SQLITE_DONE)die_sqlite("sqlite3_step");
+	SQLITE(finalize,stmt);
+
+	return rl;
+}
+
+
 i64 db_create_user(const char *name,const char *pass){
 	sqlite3_stmt *stmt;
 	SQLITE(prepare_v2,database,"insert into Users (username, pass) values (?, ?)",-1,&stmt,NULL);
@@ -48,14 +126,41 @@ i64 db_create_user(const char *name,const char *pass){
 	return userid;
 }
 
+char* db_get_pass(i64 userid){
+	sqlite3_stmt *stmt;
+	SQLITE(prepare_v2,database,"select pass from Users where id = ?",-1,&stmt,NULL);
+	SQLITE(bind_int64,stmt,1,userid);
+	const unsigned char *pass_sq=NULL;
+	if(sqlite3_step(stmt)==SQLITE_ROW){
+		pass_sq=sqlite3_column_text(stmt,0);
+	}
+	char *pass=NULL;
+	if(pass_sq)pass=strdup((const char*)pass_sq);
+	SQLITE(finalize,stmt);
+	return pass;
+}
+
 i64 db_find_user(const char *name){
 	sqlite3_stmt *stmt;
 	SQLITE(prepare_v2,database,"select user from UserNames where name = ?",-1,&stmt,NULL);
 	SQLITE(bind_text,stmt,1,name,-1,SQLITE_STATIC);
 	i64 userid=-1;
 	if(sqlite3_step(stmt)==SQLITE_ROW){
-		userid=sqlite3_column_int64(stmt,1);
+		userid=sqlite3_column_int64(stmt,0);
 	}
 	SQLITE(finalize,stmt);
 	return userid;
 }
+
+
+void db_nullify_name_id(struct db_name_id ni){
+	if(ni.name)free(ni.name);
+}
+
+void db_nullify_room_list(struct db_room_list rl){
+	for(i64 i=0;i<rl.count;i++){
+		free(rl.list[i].name);
+	}
+	if(rl.list)free(rl.list);
+	rl.list=NULL;
+}
diff --git a/db.h b/db.h
index 15127e4..911103d 100644
--- a/db.h
+++ b/db.h
@@ -8,6 +8,11 @@ struct db_name_id{
 	i64 id;
 };
 
+struct db_room_list{
+	i64 count;
+	struct db_name_id *list;
+};
+
 struct db_message{
 	i64 roomid,userid,timestamp;
 	char *message;
@@ -26,6 +31,7 @@ bool db_delete_room(i64 roomid);
 bool db_add_member(i64 roomid,i64 userid);
 bool db_remove_member(i64 roomid,u64 userid);
 i64 db_find_room(const char *name);  // -1 if not found
+struct db_room_list db_list_rooms(i64 userid);
 
 i64 db_create_user(const char *name,const char *pass);
 bool db_set_username(i64 userid,const char *name);
@@ -37,3 +43,7 @@ i64 db_find_user(const char *name);  // -1 if not found
 
 bool db_create_message(i64 roomid,i64 userid,i64 timestamp,const char *message);
 struct db_message_list db_get_messages(i64 roomid,i64 timestamp,i64 count);  // pass timestamp==-1 for last messages
+
+void db_nullify_name_id(struct db_name_id ni);
+void db_nullify_room_list(struct db_room_list rl);
+void db_nullify_message_list(struct db_message_list ml);
diff --git a/main.c b/main.c
index 1679f4b..f059778 100644
--- a/main.c
+++ b/main.c
@@ -1,8 +1,9 @@
 #include <stdio.h>
+#include <stdlib.h>
 #include <string.h>
+#include <unistd.h>
 #include <sys/socket.h>
 #include <netinet/in.h>
-#include <unistd.h>
 #include <errno.h>
 #include <assert.h>
 #include "command.h"
@@ -86,7 +87,7 @@ static bool client_socket_callback(int fd){
 	char *lfp=memchr(data->buffer,'\n',data->buflen);
 	if(lfp==NULL)return false;
 	size_t length=lfp-data->buffer;
-	bool should_close=handle_input_line(fd,data->buffer,length);
+	bool should_close=handle_input_line(data,data->buffer,length);
 	memmove(data->buffer,lfp+1,data->buflen-length-1);
 	data->buflen-=length+1;
 
@@ -114,6 +115,8 @@ static bool server_socket_callback(int fd){
 }
 
 int main(void){
+	srandomdev();
+
 	db_init();
 	int sock=create_server_socket();
 	printf("Listening on port %d\n",PORT);
diff --git a/memory.c b/memory.c
index 952bedf..aec3da2 100644
--- a/memory.c
+++ b/memory.c
@@ -7,3 +7,10 @@ void* check_after_allocation(const char *func,size_t num,size_t sz,void *ptr){
 	}
 	return ptr;
 }
+
+void* check_after_allocation_str(const char *func,void *ptr){
+	if(ptr==NULL){
+		die("Allocation failed: %s()",func);
+	}
+	return ptr;
+}
diff --git a/memory.h b/memory.h
index 970648f..19cfb95 100644
--- a/memory.h
+++ b/memory.h
@@ -8,5 +8,8 @@
 	((type*)check_after_allocation("calloc",num,sizeof(type),calloc((num),sizeof(type))))
 #define realloc(ptr,num,type) \
 	((type*)check_after_allocation("realloc",num,sizeof(type),realloc((ptr),(num)*sizeof(type))))
+#define strdup(str) \
+	((char*)check_after_allocation_str("strdup",strdup(str)))
 
 void* check_after_allocation(const char *func,size_t num,size_t sz,void *ptr);
+void* check_after_allocation_str(const char *func,void *ptr);
diff --git a/schema.sql b/schema.sql
index 5a888d6..9f6da9e 100644
--- a/schema.sql
+++ b/schema.sql
@@ -4,12 +4,7 @@ create table Rooms (
 	id integer primary key,
 	name text
 );
-
-create table RoomNames (
-	name text primary key,
-	room integer not null,
-	foreign key(room) references Rooms(id) on delete cascade
-);
+create unique index rooms_name_index on Rooms(name);
 
 create table Members (
 	room integer,
-- 
cgit v1.2.3-70-g09d2