diff options
-rw-r--r-- | command.c | 123 | ||||
-rw-r--r-- | command.h | 3 | ||||
-rw-r--r-- | conn_data.c | 4 | ||||
-rw-r--r-- | conn_data.h | 2 | ||||
-rw-r--r-- | db.c | 107 | ||||
-rw-r--r-- | db.h | 10 | ||||
-rw-r--r-- | main.c | 7 | ||||
-rw-r--r-- | memory.c | 7 | ||||
-rw-r--r-- | memory.h | 3 | ||||
-rw-r--r-- | schema.sql | 7 |
10 files changed, 249 insertions, 24 deletions
@@ -40,20 +40,108 @@ static bool send_error(int fd,const char *tag,const char *msg){ return closed; } +static bool send_name(int fd,const char *tag,const char *name){ + char *buf=NULL; + i64 len=asprintf(&buf,"%s name %s\n",tag,name); + assert(buf); + bool closed=send_raw_text(fd,buf,len); + free(buf); + return closed; +} -static bool cmd_register(int fd,const char *tag,const char **args){ +static bool send_list(int fd,const char *tag,i64 count,const char **list){ + char *buf=NULL; + i64 len=asprintf(&buf,"%s list %lld",tag,count); + assert(buf); + bool closed=send_raw_text(fd,buf,len); + free(buf); + if(closed)return true; + + if(count>0){ + i64 bufsz=64; + buf=malloc(bufsz,char); + + for(i64 i=0;i<count;i++){ + i64 len=strlen(list[i]); + if(len>=bufsz){ + bufsz=len+512; + buf=realloc(buf,bufsz,char); + } + memcpy(buf+1,list[i],len); + buf[0]=' '; + if(send_raw_text(fd,buf,len+1)){ + free(buf); + return true; + } + } + + free(buf); + } + + return send_raw_text(fd,"\n",1); +} + + +static bool cmd_register(struct conn_data *data,const char *tag,const char **args){ i64 userid=db_find_user(args[0]); if(userid!=-1){ - send_error(fd,tag,"Username already exists"); + send_error(data->fd,tag,"Username already exists"); return false; } db_create_user(args[0],args[1]); - return send_ok(fd,tag); + return send_ok(data->fd,tag); } -static bool cmd_login(int fd,const char *tag,const char **args){ - (void)fd; (void)tag; (void)args; - return true; +static bool cmd_login(struct conn_data *data,const char *tag,const char **args){ + i64 userid=db_find_user(args[0]); + if(userid==-1){ + send_error(data->fd,tag,"Username does not exist"); + return false; + } + char *pass=db_get_pass(userid); + bool success=strcmp(args[1],pass)==0; + free(pass); + if(success){ + data->userid=userid; + send_ok(data->fd,tag); + } else { + data->userid=-1; + send_error(data->fd,tag,"Incorrect password"); + } + return false; +} + +static bool cmd_list_rooms(struct conn_data *data,const char *tag,const char **args){ + (void)args; + if(data->userid==-1){ + send_error(data->fd,tag,"Not logged in"); + return false; + } + struct db_room_list rl=db_list_rooms(data->userid); + if(rl.count<=0){ + db_nullify_room_list(rl); + return send_list(data->fd,tag,0,NULL); + } + const char *names[rl.count]; + for(i64 i=0;i<rl.count;i++){ + names[i]=rl.list[i].name; + } + bool closed=send_list(data->fd,tag,rl.count,names); + db_nullify_room_list(rl); + return closed; +} + +static bool cmd_create_room(struct conn_data *data,const char *tag,const char **args){ + (void)args; + if(data->userid==-1){ + send_error(data->fd,tag,"Not logged in"); + return false; + } + struct db_name_id room=db_create_room(); + db_add_member(room.id,data->userid); + bool closed=send_name(data->fd,tag,room.name); + db_nullify_name_id(room); + return closed; } @@ -61,20 +149,22 @@ struct cmd_info{ const char *cmdname; int nargs; bool longlast; // whether the last argument should span the rest of the input line - bool (*handler)(int fd,const char *tag,const char **args); + bool (*handler)(struct conn_data *data,const char *tag,const char **args); }; static const struct cmd_info commands[]={ - {"register",2,false,cmd_register}, - {"login",2,false,cmd_login}, + {"register",2,true,cmd_register}, + {"login",2,true,cmd_login}, + {"list_rooms",0,false,cmd_list_rooms}, + {"create_room",0,false,cmd_create_room}, }; #define NCOMMANDS (sizeof(commands)/sizeof(commands[0])) -bool handle_input_line(int fd,char *line,size_t linelen){ +bool handle_input_line(struct conn_data *data,char *line,size_t linelen){ char *sepp=memchr(line,' ',linelen); if(sepp==NULL){ - debug("No space in input line from connection %d",fd); + debug("No space in input line from connection %d",data->fd); return true; } char *tag=line; @@ -93,13 +183,18 @@ bool handle_input_line(int fd,char *line,size_t linelen){ } } + if(cmdi==NCOMMANDS){ + debug("Unknown command %s on connection %d",line,data->fd); + return true; + } + int nargs=commands[cmdi].nargs; char *args[nargs]; size_t cursor=cmdlen+1; for(int i=0;i<nargs;i++){ if(cursor>=linelen){ - debug("Connection %d sent too few parameters to command %s",fd,commands[cmdi].cmdname); + debug("Connection %d sent too few parameters to command %s",data->fd,commands[cmdi].cmdname); return true; } if(i==nargs-1&&commands[cmdi].longlast){ @@ -113,9 +208,9 @@ bool handle_input_line(int fd,char *line,size_t linelen){ cursor=sepp-line+1; } if(sepp-line<(i64)linelen){ - debug("Connection %d sent too many parameters to command %s",fd,commands[cmdi].cmdname); + debug("Connection %d sent too many parameters to command %s",data->fd,commands[cmdi].cmdname); return true; } - return commands[cmdi].handler(fd,tag,(const char**)args); + return commands[cmdi].handler(data,tag,(const char**)args); } @@ -1,8 +1,9 @@ #pragma once #include "global.h" +#include "conn_data.h" // Returns true if socket should be closed. // Modifies some bytes in `line`, AS WELL AS line[linelen]! -bool handle_input_line(int fd,char *line,size_t linelen); +bool handle_input_line(struct conn_data *data,char *line,size_t linelen); diff --git a/conn_data.c b/conn_data.c index 1b13cec..1b237b4 100644 --- a/conn_data.c +++ b/conn_data.c @@ -6,6 +6,8 @@ void conn_data_init(struct conn_data *data,int fd){ data->bufsz=512; data->buflen=0; data->buffer=malloc(data->bufsz,char); + + data->userid=-1; } void conn_data_nullify(struct conn_data *data){ @@ -13,4 +15,6 @@ void conn_data_nullify(struct conn_data *data){ data->buffer=NULL; data->bufsz=0; data->buflen=0; + + data->userid=-1; } diff --git a/conn_data.h b/conn_data.h index 871af07..dae100f 100644 --- a/conn_data.h +++ b/conn_data.h @@ -7,6 +7,8 @@ struct conn_data{ int fd; i64 bufsz,buflen; char *buffer; + + i64 userid; // -1 if not logged in }; void conn_data_init(struct conn_data *data,int fd); // Initialises buffers @@ -1,4 +1,6 @@ +#include <stdlib.h> #include <string.h> +#include <assert.h> #include <sqlite3.h> #include "db.h" #include "schema.sql.h" @@ -29,6 +31,82 @@ void db_close(void){ } +static char* gen_room_name(void){ + const int name_len=8; + const char *alphabet="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + const int alpha_len=strlen(alphabet); + + char *name=malloc(name_len+1,char); + for(int i=0;i<name_len;i++){ + name[i]=alphabet[random()%alpha_len]; + } + name[name_len]='\0'; + debug("Generated name: %s",name); + return name; +} + +struct db_name_id db_create_room(void){ + char *name=gen_room_name(); + + sqlite3_stmt *stmt; + SQLITE(prepare_v2,database,"insert into Rooms (name) values (?)",-1,&stmt,NULL); + SQLITE(bind_text,stmt,1,name,-1,SQLITE_STATIC); + bool success=sqlite3_step(stmt)==SQLITE_DONE; + SQLITE(finalize,stmt); + + i64 userid=sqlite3_last_insert_rowid(database); + + if(!success){ + free(name); + return (struct db_name_id){NULL,-1}; + } + + return (struct db_name_id){name,userid}; +} + +bool db_add_member(i64 roomid,i64 userid){ + assert(roomid!=-1&&userid!=-1); + sqlite3_stmt *stmt; + SQLITE(prepare_v2,database,"insert into Members (room, user) values (?, ?)",-1,&stmt,NULL); + SQLITE(bind_int64,stmt,1,roomid); + SQLITE(bind_int64,stmt,2,userid); + bool success=sqlite3_step(stmt)==SQLITE_DONE; + SQLITE(finalize,stmt); + return success; +} + +struct db_room_list db_list_rooms(i64 userid){ + sqlite3_stmt *stmt; + SQLITE(prepare_v2,database, + "select M.room, R.name " + "from Members as M, Rooms as R " + "where M.user = ? and M.room = R.id" + ,-1,&stmt,NULL); + SQLITE(bind_int64,stmt,1,userid); + + struct db_room_list rl; + i64 cap=4; + rl.count=0; + rl.list=malloc(cap,struct db_name_id); + + int ret; + while((ret=sqlite3_step(stmt))==SQLITE_ROW){ + if(rl.count==cap){ + cap*=2; + rl.list=realloc(rl.list,cap,struct db_name_id); + } + rl.list[rl.count].id=sqlite3_column_int64(stmt,0); + rl.list[rl.count].name=strdup((const char*)sqlite3_column_text(stmt,1)); + rl.count++; + } + + if(ret!=SQLITE_DONE)die_sqlite("sqlite3_step"); + SQLITE(finalize,stmt); + + return rl; +} + + i64 db_create_user(const char *name,const char *pass){ sqlite3_stmt *stmt; SQLITE(prepare_v2,database,"insert into Users (username, pass) values (?, ?)",-1,&stmt,NULL); @@ -48,14 +126,41 @@ i64 db_create_user(const char *name,const char *pass){ return userid; } +char* db_get_pass(i64 userid){ + sqlite3_stmt *stmt; + SQLITE(prepare_v2,database,"select pass from Users where id = ?",-1,&stmt,NULL); + SQLITE(bind_int64,stmt,1,userid); + const unsigned char *pass_sq=NULL; + if(sqlite3_step(stmt)==SQLITE_ROW){ + pass_sq=sqlite3_column_text(stmt,0); + } + char *pass=NULL; + if(pass_sq)pass=strdup((const char*)pass_sq); + SQLITE(finalize,stmt); + return pass; +} + i64 db_find_user(const char *name){ sqlite3_stmt *stmt; SQLITE(prepare_v2,database,"select user from UserNames where name = ?",-1,&stmt,NULL); SQLITE(bind_text,stmt,1,name,-1,SQLITE_STATIC); i64 userid=-1; if(sqlite3_step(stmt)==SQLITE_ROW){ - userid=sqlite3_column_int64(stmt,1); + userid=sqlite3_column_int64(stmt,0); } SQLITE(finalize,stmt); return userid; } + + +void db_nullify_name_id(struct db_name_id ni){ + if(ni.name)free(ni.name); +} + +void db_nullify_room_list(struct db_room_list rl){ + for(i64 i=0;i<rl.count;i++){ + free(rl.list[i].name); + } + if(rl.list)free(rl.list); + rl.list=NULL; +} @@ -8,6 +8,11 @@ struct db_name_id{ i64 id; }; +struct db_room_list{ + i64 count; + struct db_name_id *list; +}; + struct db_message{ i64 roomid,userid,timestamp; char *message; @@ -26,6 +31,7 @@ bool db_delete_room(i64 roomid); bool db_add_member(i64 roomid,i64 userid); bool db_remove_member(i64 roomid,u64 userid); i64 db_find_room(const char *name); // -1 if not found +struct db_room_list db_list_rooms(i64 userid); i64 db_create_user(const char *name,const char *pass); bool db_set_username(i64 userid,const char *name); @@ -37,3 +43,7 @@ i64 db_find_user(const char *name); // -1 if not found bool db_create_message(i64 roomid,i64 userid,i64 timestamp,const char *message); struct db_message_list db_get_messages(i64 roomid,i64 timestamp,i64 count); // pass timestamp==-1 for last messages + +void db_nullify_name_id(struct db_name_id ni); +void db_nullify_room_list(struct db_room_list rl); +void db_nullify_message_list(struct db_message_list ml); @@ -1,8 +1,9 @@ #include <stdio.h> +#include <stdlib.h> #include <string.h> +#include <unistd.h> #include <sys/socket.h> #include <netinet/in.h> -#include <unistd.h> #include <errno.h> #include <assert.h> #include "command.h" @@ -86,7 +87,7 @@ static bool client_socket_callback(int fd){ char *lfp=memchr(data->buffer,'\n',data->buflen); if(lfp==NULL)return false; size_t length=lfp-data->buffer; - bool should_close=handle_input_line(fd,data->buffer,length); + bool should_close=handle_input_line(data,data->buffer,length); memmove(data->buffer,lfp+1,data->buflen-length-1); data->buflen-=length+1; @@ -114,6 +115,8 @@ static bool server_socket_callback(int fd){ } int main(void){ + srandomdev(); + db_init(); int sock=create_server_socket(); printf("Listening on port %d\n",PORT); @@ -7,3 +7,10 @@ void* check_after_allocation(const char *func,size_t num,size_t sz,void *ptr){ } return ptr; } + +void* check_after_allocation_str(const char *func,void *ptr){ + if(ptr==NULL){ + die("Allocation failed: %s()",func); + } + return ptr; +} @@ -8,5 +8,8 @@ ((type*)check_after_allocation("calloc",num,sizeof(type),calloc((num),sizeof(type)))) #define realloc(ptr,num,type) \ ((type*)check_after_allocation("realloc",num,sizeof(type),realloc((ptr),(num)*sizeof(type)))) +#define strdup(str) \ + ((char*)check_after_allocation_str("strdup",strdup(str))) void* check_after_allocation(const char *func,size_t num,size_t sz,void *ptr); +void* check_after_allocation_str(const char *func,void *ptr); @@ -4,12 +4,7 @@ create table Rooms ( id integer primary key, name text ); - -create table RoomNames ( - name text primary key, - room integer not null, - foreign key(room) references Rooms(id) on delete cascade -); +create unique index rooms_name_index on Rooms(name); create table Members ( room integer, |