diff options
-rw-r--r-- | .gitignore | 4 | ||||
-rw-r--r-- | Makefile | 31 | ||||
-rw-r--r-- | command.c | 121 | ||||
-rw-r--r-- | command.h | 8 | ||||
-rw-r--r-- | conn_data.c | 16 | ||||
-rw-r--r-- | conn_data.h | 13 | ||||
-rw-r--r-- | db.c | 61 | ||||
-rw-r--r-- | db.h | 39 | ||||
-rw-r--r-- | global.c | 33 | ||||
-rw-r--r-- | global.h | 14 | ||||
-rw-r--r-- | main.c | 123 | ||||
-rw-r--r-- | memory.c | 9 | ||||
-rw-r--r-- | memory.h | 12 | ||||
-rw-r--r-- | runloop.c | 53 | ||||
-rw-r--r-- | runloop.h | 10 | ||||
-rw-r--r-- | schema.sql | 43 |
16 files changed, 590 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3623175 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.o +*.sql.h +tomsg_server +db.db diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8922662 --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +CC = gcc +CFLAGS = -Wall -Wextra -std=c11 -g -fwrapv +LDFLAGS = -lsqlite3 + +TARGETS = tomsg_server + +.PHONY: all clean remake + +# Clear all implicit suffix rules +.SUFFIXES: + +# Don't delete intermediate files +.SECONDARY: + +all: $(TARGETS) + +clean: + rm -f $(TARGETS) *.o *.sql.h + +remake: clean + $(MAKE) all + + +$(TARGETS): $(patsubst %.c,%.o,$(wildcard *.c)) + $(CC) -o $@ $^ $(LDFLAGS) + +%.o: %.c $(wildcard *.h) $(patsubst %.sql,%.sql.h,$(wildcard *.sql)) + $(CC) $(CFLAGS) -c -o $@ $< + +%.sql.h: %.sql + xxd -i $^ $@ diff --git a/command.c b/command.c new file mode 100644 index 0000000..deaeabe --- /dev/null +++ b/command.c @@ -0,0 +1,121 @@ +#define _GNU_SOURCE +#include <stdio.h> +#include <string.h> +#include <errno.h> +#include <assert.h> +#include <sys/socket.h> +#include "command.h" +#include "db.h" + + +static bool send_raw_text(int fd,const char *text,i64 len){ + i64 cursor=0; + while(cursor<len){ + i64 nwr=send(fd,text+cursor,len-cursor,0); + if(nwr<0){ + if(errno==EINTR)continue; + if(errno==ECONNRESET||errno==EPIPE)return true; + die_perror("send"); + } + cursor+=nwr; + } + return false; +} + +static bool send_ok(int fd,const char *tag){ + char *buf=NULL; + i64 len=asprintf(&buf,"%s ok\n",tag); + assert(buf); + bool closed=send_raw_text(fd,buf,len); + free(buf); + return closed; +} + +static bool send_error(int fd,const char *tag,const char *msg){ + char *buf=NULL; + i64 len=asprintf(&buf,"%s error %s\n",tag,msg); + assert(buf); + bool closed=send_raw_text(fd,buf,len); + free(buf); + return closed; +} + + +static bool cmd_register(int fd,const char *tag,const char **args){ + i64 userid=db_find_user(args[0]); + if(userid!=-1){ + send_error(fd,tag,"Username already exists"); + return false; + } + db_create_user(args[0],args[1]); + return send_ok(fd,tag); +} + +static bool cmd_login(int fd,const char *tag,const char **args){ + (void)fd; (void)tag; (void)args; + return true; +} + + +struct cmd_info{ + const char *cmdname; + int nargs; + bool longlast; // whether the last argument should span the rest of the input line + bool (*handler)(int fd,const char *tag,const char **args); +}; + +static const struct cmd_info commands[]={ + {"register",2,false,cmd_register}, + {"login",2,false,cmd_login}, +}; +#define NCOMMANDS (sizeof(commands)/sizeof(commands[0])) + + +bool handle_input_line(int fd,char *line,size_t linelen){ + char *sepp=memchr(line,' ',linelen); + if(sepp==NULL){ + debug("No space in input line from connection %d",fd); + return true; + } + char *tag=line; + size_t taglen=sepp-tag; + *sepp='\0'; + line+=taglen+1; + linelen-=taglen+1; + + sepp=memchr(line,' ',linelen); + if(sepp==NULL)sepp=line+linelen; + size_t cmdlen=sepp-line; + size_t cmdi; + for(cmdi=0;cmdi<NCOMMANDS;cmdi++){ + if(strncmp(line,commands[cmdi].cmdname,cmdlen)==0){ + break; + } + } + + int nargs=commands[cmdi].nargs; + char *args[nargs]; + size_t cursor=cmdlen+1; + + for(int i=0;i<nargs;i++){ + if(cursor>=linelen){ + debug("Connection %d sent too few parameters to command %s",fd,commands[cmdi].cmdname); + return true; + } + if(i==nargs-1&&commands[cmdi].longlast){ + sepp=line+linelen; + } else { + sepp=memchr(line+cursor,' ',linelen-cursor); + if(sepp==NULL)sepp=line+linelen; + } + *sepp='\0'; + args[i]=line+cursor; + cursor=sepp-line+1; + } + if(sepp-line<(i64)linelen){ + debug("Connection %d sent too many parameters to command %s",fd,commands[cmdi].cmdname); + return true; + } + + return commands[cmdi].handler(fd,tag,(const char**)args); +} diff --git a/command.h b/command.h new file mode 100644 index 0000000..9bfac7f --- /dev/null +++ b/command.h @@ -0,0 +1,8 @@ +#pragma once + +#include "global.h" + + +// Returns true if socket should be closed. +// Modifies some bytes in `line`, AS WELL AS line[linelen]! +bool handle_input_line(int fd,char *line,size_t linelen); diff --git a/conn_data.c b/conn_data.c new file mode 100644 index 0000000..1b13cec --- /dev/null +++ b/conn_data.c @@ -0,0 +1,16 @@ +#include "conn_data.h" + + +void conn_data_init(struct conn_data *data,int fd){ + data->fd=fd; + data->bufsz=512; + data->buflen=0; + data->buffer=malloc(data->bufsz,char); +} + +void conn_data_nullify(struct conn_data *data){ + free(data->buffer); + data->buffer=NULL; + data->bufsz=0; + data->buflen=0; +} diff --git a/conn_data.h b/conn_data.h new file mode 100644 index 0000000..871af07 --- /dev/null +++ b/conn_data.h @@ -0,0 +1,13 @@ +#pragma once + +#include "global.h" + + +struct conn_data{ + int fd; + i64 bufsz,buflen; + char *buffer; +}; + +void conn_data_init(struct conn_data *data,int fd); // Initialises buffers +void conn_data_nullify(struct conn_data *data); // Frees buffers but not the data itself @@ -0,0 +1,61 @@ +#include <string.h> +#include <sqlite3.h> +#include "db.h" +#include "schema.sql.h" + + +#define SQLITE(func,...) do{if(sqlite3_##func(__VA_ARGS__)!=SQLITE_OK){die_sqlite("sqlite3_" #func);}}while(0) + + +sqlite3 *database; + +__attribute__((noreturn)) +static void die_sqlite(const char *func){ + die("%s: %s",func,sqlite3_errmsg(database)); +} + + +void db_init(void){ + SQLITE(open_v2,"db.db",&database,SQLITE_OPEN_READWRITE|SQLITE_OPEN_CREATE,NULL); + char *str=malloc(schema_sql_len+1,char); + memcpy(str,schema_sql,schema_sql_len); + str[schema_sql_len]='\0'; + sqlite3_exec(database,str,NULL,NULL,NULL); + free(str); +} + +void db_close(void){ + sqlite3_close(database); +} + + +i64 db_create_user(const char *name,const char *pass){ + sqlite3_stmt *stmt; + SQLITE(prepare_v2,database,"insert into Users (username, pass) values (?, ?)",-1,&stmt,NULL); + SQLITE(bind_text,stmt,1,name,-1,SQLITE_STATIC); + SQLITE(bind_text,stmt,2,pass,-1,SQLITE_STATIC); + if(sqlite3_step(stmt)!=SQLITE_DONE)die_sqlite("sqlite3_step"); + SQLITE(finalize,stmt); + + i64 userid=sqlite3_last_insert_rowid(database); + + SQLITE(prepare_v2,database,"insert into UserNames (name, user) values (?, ?)",-1,&stmt,NULL); + SQLITE(bind_text,stmt,1,name,-1,SQLITE_STATIC); + SQLITE(bind_int64,stmt,2,userid); + if(sqlite3_step(stmt)!=SQLITE_DONE)die_sqlite("sqlite3_step"); + SQLITE(finalize,stmt); + + return userid; +} + +i64 db_find_user(const char *name){ + sqlite3_stmt *stmt; + SQLITE(prepare_v2,database,"select user from UserNames where name = ?",-1,&stmt,NULL); + SQLITE(bind_text,stmt,1,name,-1,SQLITE_STATIC); + i64 userid=-1; + if(sqlite3_step(stmt)==SQLITE_ROW){ + userid=sqlite3_column_int64(stmt,1); + } + SQLITE(finalize,stmt); + return userid; +} @@ -0,0 +1,39 @@ +#pragma once + +#include "global.h" + + +struct db_name_id{ + char *name; + i64 id; +}; + +struct db_message{ + i64 roomid,userid,timestamp; + char *message; +}; + +struct db_message_list{ + i64 count; + struct db_message *list; +}; + +void db_init(void); +void db_close(void); + +struct db_name_id db_create_room(void); +bool db_delete_room(i64 roomid); +bool db_add_member(i64 roomid,i64 userid); +bool db_remove_member(i64 roomid,u64 userid); +i64 db_find_room(const char *name); // -1 if not found + +i64 db_create_user(const char *name,const char *pass); +bool db_set_username(i64 userid,const char *name); +bool db_set_pass(i64 userid,const char *pass); +char* db_get_username(i64 userid); +char* db_get_pass(i64 userid); +bool db_delete_user(i64 userid); +i64 db_find_user(const char *name); // -1 if not found + +bool db_create_message(i64 roomid,i64 userid,i64 timestamp,const char *message); +struct db_message_list db_get_messages(i64 roomid,i64 timestamp,i64 count); // pass timestamp==-1 for last messages diff --git a/global.c b/global.c new file mode 100644 index 0000000..bf5bf2e --- /dev/null +++ b/global.c @@ -0,0 +1,33 @@ +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <stdarg.h> +#include <errno.h> +#include "global.h" + +__attribute__((noreturn, format(printf, 1, 2))) +void die(const char *format,...){ + fprintf(stderr,"DIE: "); + va_list ap; + va_start(ap,format); + vfprintf(stderr,format,ap); + va_end(ap); + fputc('\n',stderr); + exit(1); +} + +__attribute__((noreturn)) +void die_perror(const char *func){ + fprintf(stderr,"DIE: %s: %s\n",func,strerror(errno)); + exit(1); +} + +__attribute__((format (printf, 1, 2))) +void debug(const char *format,...){ + fprintf(stderr,"DEBUG: "); + va_list ap; + va_start(ap,format); + vfprintf(stderr,format,ap); + va_end(ap); + fputc('\n',stderr); +} diff --git a/global.h b/global.h new file mode 100644 index 0000000..de0873b --- /dev/null +++ b/global.h @@ -0,0 +1,14 @@ +#pragma once + +#include <stdbool.h> +#include <stdint.h> +#include <inttypes.h> +#include "memory.h" + +typedef int64_t i64; +typedef uint64_t u64; + +void die(const char *format,...) __attribute__((noreturn, format(printf, 1, 2))); +void die_perror(const char *func) __attribute__((noreturn)); + +void debug(const char *format,...) __attribute__((format(printf, 1, 2))); @@ -0,0 +1,123 @@ +#include <stdio.h> +#include <string.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <unistd.h> +#include <errno.h> +#include <assert.h> +#include "command.h" +#include "conn_data.h" +#include "db.h" +#include "runloop.h" + +#define PORT (29536) // python: int("msg",36) + + +static int create_server_socket(void){ + int sock=socket(AF_INET,SOCK_STREAM,0); + if(sock<0)die_perror("socket"); + int one=1; + setsockopt(sock,SOL_SOCKET,SO_REUSEADDR,&one,sizeof one); + + struct sockaddr_in name; + name.sin_family=AF_INET; + name.sin_addr.s_addr=htonl(INADDR_ANY); + name.sin_port=htons(PORT); + if(bind(sock,(struct sockaddr*)&name,sizeof name)<0)die_perror("bind"); + + if(listen(sock,16)<0)die_perror("listen"); + return sock; +} + +struct hash_item{ + struct conn_data cd; + struct hash_item *next; +}; + +#define CONN_HASH_SIZE (16) +static struct hash_item *conn_hash[CONN_HASH_SIZE]; + +static struct conn_data* find_conn_data(int fd){ + struct hash_item *item=conn_hash[fd%CONN_HASH_SIZE]; + while(item&&item->cd.fd!=fd)item=item->next; + assert(item); + return &item->cd; +} + +static void delete_conn_data(int fd){ + struct hash_item *item=conn_hash[fd%CONN_HASH_SIZE]; + assert(item); + if(item->cd.fd==fd){ + conn_hash[fd%CONN_HASH_SIZE]=item->next; + conn_data_nullify(&item->cd); + free(item); + return; + } + struct hash_item *parent=NULL; + while(item&&item->cd.fd!=fd){ + parent=item; + item=item->next; + } + assert(parent); + assert(item); + conn_data_nullify(&item->cd); + parent->next=item->next; + free(item); +} + +static bool client_socket_callback(int fd){ + struct conn_data *data=find_conn_data(fd); + if(data->bufsz-data->buflen<256){ + data->bufsz*=2; + data->buffer=realloc(data->buffer,data->bufsz,char); + } + + ssize_t ret; + do ret=read(fd,data->buffer+data->buflen,data->bufsz-data->buflen); + while(ret<0&&errno==EINTR); + if(ret<0)die_perror("read"); + if(ret==0){ + delete_conn_data(fd); + close(fd); + return true; + } + data->buflen+=ret; + + char *lfp=memchr(data->buffer,'\n',data->buflen); + if(lfp==NULL)return false; + size_t length=lfp-data->buffer; + bool should_close=handle_input_line(fd,data->buffer,length); + memmove(data->buffer,lfp+1,data->buflen-length-1); + data->buflen-=length+1; + + if(should_close){ + delete_conn_data(fd); + close(fd); + return true; + } + return false; +} + +static bool server_socket_callback(int fd){ + int sock; + do sock=accept(fd,NULL,NULL); + while(sock<0&&errno==EINTR); + if(sock<0)die_perror("accept"); + runloop_add_fd(sock,client_socket_callback); + + struct hash_item *item=malloc(1,struct hash_item); + conn_data_init(&item->cd,sock); + item->next=conn_hash[sock%CONN_HASH_SIZE]; + conn_hash[sock%CONN_HASH_SIZE]=item; + debug("Added conn_data for fd=%d",sock); + return false; +} + +int main(void){ + db_init(); + int sock=create_server_socket(); + printf("Listening on port %d\n",PORT); + runloop_add_fd(sock,server_socket_callback); + runloop_run(); + db_close(); +} diff --git a/memory.c b/memory.c new file mode 100644 index 0000000..952bedf --- /dev/null +++ b/memory.c @@ -0,0 +1,9 @@ +#include "global.h" +#include "memory.h" + +void* check_after_allocation(const char *func,size_t num,size_t sz,void *ptr){ + if(ptr==NULL){ + die("Allocation failed: %s(%zu * %zuB = %zu)",func,num,sz,num*sz); + } + return ptr; +} diff --git a/memory.h b/memory.h new file mode 100644 index 0000000..970648f --- /dev/null +++ b/memory.h @@ -0,0 +1,12 @@ +#pragma once + +#include <stdlib.h> + +#define malloc(num,type) \ + ((type*)check_after_allocation("malloc",num,sizeof(type),malloc((num)*sizeof(type)))) +#define calloc(num,type) \ + ((type*)check_after_allocation("calloc",num,sizeof(type),calloc((num),sizeof(type)))) +#define realloc(ptr,num,type) \ + ((type*)check_after_allocation("realloc",num,sizeof(type),realloc((ptr),(num)*sizeof(type)))) + +void* check_after_allocation(const char *func,size_t num,size_t sz,void *ptr); diff --git a/runloop.c b/runloop.c new file mode 100644 index 0000000..d2865a2 --- /dev/null +++ b/runloop.c @@ -0,0 +1,53 @@ +#include <string.h> +#include <sys/select.h> +#include "runloop.h" + + +struct fd_list_item{ + int fd; + runloop_callback *func; +}; + +static struct fd_list_item *fd_list; +static size_t fd_list_len,fd_list_cap; + +__attribute__((constructor)) +static void constructor(void){ + fd_list_cap=16; + fd_list_len=0; + fd_list=malloc(fd_list_cap,struct fd_list_item); +} + + +void runloop_add_fd(int fd,runloop_callback *func){ + if(fd_list_len==fd_list_cap){ + fd_list_cap*=2; + fd_list=realloc(fd_list,fd_list_cap,struct fd_list_item); + } + fd_list[fd_list_len].fd=fd; + fd_list[fd_list_len].func=func; + fd_list_len++; +} + +void runloop_run(void){ + while(true){ + fd_set inset; + FD_ZERO(&inset); + int maxfd=-1; + for(size_t i=0;i<fd_list_len;i++){ + FD_SET(fd_list[i].fd,&inset); + if(fd_list[i].fd>maxfd)maxfd=fd_list[i].fd; + } + int ret=select(maxfd+1,&inset,NULL,NULL,NULL); + if(ret<=0)die_perror("select"); + for(size_t i=0;i<fd_list_len;i++){ + if(FD_ISSET(fd_list[i].fd,&inset)){ + if(fd_list[i].func(fd_list[i].fd)){ + memmove(fd_list+i,fd_list+i+1,fd_list_len-i-1); + i--; + fd_list_len--; + } + } + } + } +} diff --git a/runloop.h b/runloop.h new file mode 100644 index 0000000..fc1f7e8 --- /dev/null +++ b/runloop.h @@ -0,0 +1,10 @@ +#pragma once + +#include "global.h" + + +// Return true to remove fd from runloop +typedef bool runloop_callback(int fd); + +void runloop_add_fd(int fd,runloop_callback *func); +void runloop_run(void); // Returns when empty diff --git a/schema.sql b/schema.sql new file mode 100644 index 0000000..5a888d6 --- /dev/null +++ b/schema.sql @@ -0,0 +1,43 @@ +pragma foreign_keys = on; + +create table Rooms ( + id integer primary key, + name text +); + +create table RoomNames ( + name text primary key, + room integer not null, + foreign key(room) references Rooms(id) on delete cascade +); + +create table Members ( + room integer, + user integer, + primary key(room, user), + foreign key(room) references Rooms(id) on delete cascade, + foreign key(user) references Users(id) on delete cascade +); + +create table Users ( + id integer primary key, + username text, + pass text +); + +create table UserNames ( + name text primary key, + user integer, + foreign key(user) references Users(id) on delete cascade +); + +create table Messages ( + id integer primary key, + room integer not null, + user integer null, + time integer not null, + message text, + foreign key(room) references Rooms(id) on delete cascade, + foreign key(user) references Users(id) on delete set null +); +create index messages_time_index on Messages(time); |