diff options
| -rw-r--r-- | .gitignore | 4 | ||||
| -rw-r--r-- | Makefile | 31 | ||||
| -rw-r--r-- | command.c | 121 | ||||
| -rw-r--r-- | command.h | 8 | ||||
| -rw-r--r-- | conn_data.c | 16 | ||||
| -rw-r--r-- | conn_data.h | 13 | ||||
| -rw-r--r-- | db.c | 61 | ||||
| -rw-r--r-- | db.h | 39 | ||||
| -rw-r--r-- | global.c | 33 | ||||
| -rw-r--r-- | global.h | 14 | ||||
| -rw-r--r-- | main.c | 123 | ||||
| -rw-r--r-- | memory.c | 9 | ||||
| -rw-r--r-- | memory.h | 12 | ||||
| -rw-r--r-- | runloop.c | 53 | ||||
| -rw-r--r-- | runloop.h | 10 | ||||
| -rw-r--r-- | schema.sql | 43 | 
16 files changed, 590 insertions, 0 deletions
| diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3623175 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.o +*.sql.h +tomsg_server +db.db diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8922662 --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +CC = gcc +CFLAGS = -Wall -Wextra -std=c11 -g -fwrapv +LDFLAGS = -lsqlite3 + +TARGETS = tomsg_server + +.PHONY: all clean remake + +# Clear all implicit suffix rules +.SUFFIXES: + +# Don't delete intermediate files +.SECONDARY: + +all: $(TARGETS) + +clean: +	rm -f $(TARGETS) *.o *.sql.h + +remake: clean +	$(MAKE) all + + +$(TARGETS): $(patsubst %.c,%.o,$(wildcard *.c)) +	$(CC) -o $@ $^ $(LDFLAGS) + +%.o: %.c $(wildcard *.h) $(patsubst %.sql,%.sql.h,$(wildcard *.sql)) +	$(CC) $(CFLAGS) -c -o $@ $< + +%.sql.h: %.sql +	xxd -i $^ $@ diff --git a/command.c b/command.c new file mode 100644 index 0000000..deaeabe --- /dev/null +++ b/command.c @@ -0,0 +1,121 @@ +#define _GNU_SOURCE +#include <stdio.h> +#include <string.h> +#include <errno.h> +#include <assert.h> +#include <sys/socket.h> +#include "command.h" +#include "db.h" + + +static bool send_raw_text(int fd,const char *text,i64 len){ +	i64 cursor=0; +	while(cursor<len){ +		i64 nwr=send(fd,text+cursor,len-cursor,0); +		if(nwr<0){ +			if(errno==EINTR)continue; +			if(errno==ECONNRESET||errno==EPIPE)return true; +			die_perror("send"); +		} +		cursor+=nwr; +	} +	return false; +} + +static bool send_ok(int fd,const char *tag){ +	char *buf=NULL; +	i64 len=asprintf(&buf,"%s ok\n",tag); +	assert(buf); +	bool closed=send_raw_text(fd,buf,len); +	free(buf); +	return closed; +} + +static bool send_error(int fd,const char *tag,const char *msg){ +	char *buf=NULL; +	i64 len=asprintf(&buf,"%s error %s\n",tag,msg); +	assert(buf); +	bool closed=send_raw_text(fd,buf,len); +	free(buf); +	return closed; +} + + +static bool cmd_register(int fd,const char *tag,const char **args){ +	i64 userid=db_find_user(args[0]); +	if(userid!=-1){ +		send_error(fd,tag,"Username already exists"); +		return false; +	} +	db_create_user(args[0],args[1]); +	return send_ok(fd,tag); +} + +static bool cmd_login(int fd,const char *tag,const char **args){ +	(void)fd; (void)tag; (void)args; +	return true; +} + + +struct cmd_info{ +	const char *cmdname; +	int nargs; +	bool longlast;  // whether the last argument should span the rest of the input line +	bool (*handler)(int fd,const char *tag,const char **args); +}; + +static const struct cmd_info commands[]={ +	{"register",2,false,cmd_register}, +	{"login",2,false,cmd_login}, +}; +#define NCOMMANDS (sizeof(commands)/sizeof(commands[0])) + + +bool handle_input_line(int fd,char *line,size_t linelen){ +	char *sepp=memchr(line,' ',linelen); +	if(sepp==NULL){ +		debug("No space in input line from connection %d",fd); +		return true; +	} +	char *tag=line; +	size_t taglen=sepp-tag; +	*sepp='\0'; +	line+=taglen+1; +	linelen-=taglen+1; + +	sepp=memchr(line,' ',linelen); +	if(sepp==NULL)sepp=line+linelen; +	size_t cmdlen=sepp-line; +	size_t cmdi; +	for(cmdi=0;cmdi<NCOMMANDS;cmdi++){ +		if(strncmp(line,commands[cmdi].cmdname,cmdlen)==0){ +			break; +		} +	} + +	int nargs=commands[cmdi].nargs; +	char *args[nargs]; +	size_t cursor=cmdlen+1; + +	for(int i=0;i<nargs;i++){ +		if(cursor>=linelen){ +			debug("Connection %d sent too few parameters to command %s",fd,commands[cmdi].cmdname); +			return true; +		} +		if(i==nargs-1&&commands[cmdi].longlast){ +			sepp=line+linelen; +		} else { +			sepp=memchr(line+cursor,' ',linelen-cursor); +			if(sepp==NULL)sepp=line+linelen; +		} +		*sepp='\0'; +		args[i]=line+cursor; +		cursor=sepp-line+1; +	} +	if(sepp-line<(i64)linelen){ +		debug("Connection %d sent too many parameters to command %s",fd,commands[cmdi].cmdname); +		return true; +	} + +	return commands[cmdi].handler(fd,tag,(const char**)args); +} diff --git a/command.h b/command.h new file mode 100644 index 0000000..9bfac7f --- /dev/null +++ b/command.h @@ -0,0 +1,8 @@ +#pragma once + +#include "global.h" + + +// Returns true if socket should be closed. +// Modifies some bytes in `line`, AS WELL AS line[linelen]! +bool handle_input_line(int fd,char *line,size_t linelen); diff --git a/conn_data.c b/conn_data.c new file mode 100644 index 0000000..1b13cec --- /dev/null +++ b/conn_data.c @@ -0,0 +1,16 @@ +#include "conn_data.h" + + +void conn_data_init(struct conn_data *data,int fd){ +	data->fd=fd; +	data->bufsz=512; +	data->buflen=0; +	data->buffer=malloc(data->bufsz,char); +} + +void conn_data_nullify(struct conn_data *data){ +	free(data->buffer); +	data->buffer=NULL; +	data->bufsz=0; +	data->buflen=0; +} diff --git a/conn_data.h b/conn_data.h new file mode 100644 index 0000000..871af07 --- /dev/null +++ b/conn_data.h @@ -0,0 +1,13 @@ +#pragma once + +#include "global.h" + + +struct conn_data{ +	int fd; +	i64 bufsz,buflen; +	char *buffer; +}; + +void conn_data_init(struct conn_data *data,int fd);  // Initialises buffers +void conn_data_nullify(struct conn_data *data);  // Frees buffers but not the data itself @@ -0,0 +1,61 @@ +#include <string.h> +#include <sqlite3.h> +#include "db.h" +#include "schema.sql.h" + + +#define SQLITE(func,...) do{if(sqlite3_##func(__VA_ARGS__)!=SQLITE_OK){die_sqlite("sqlite3_" #func);}}while(0) + + +sqlite3 *database; + +__attribute__((noreturn)) +static void die_sqlite(const char *func){ +	die("%s: %s",func,sqlite3_errmsg(database)); +} + + +void db_init(void){ +	SQLITE(open_v2,"db.db",&database,SQLITE_OPEN_READWRITE|SQLITE_OPEN_CREATE,NULL); +	char *str=malloc(schema_sql_len+1,char); +	memcpy(str,schema_sql,schema_sql_len); +	str[schema_sql_len]='\0'; +	sqlite3_exec(database,str,NULL,NULL,NULL); +	free(str); +} + +void db_close(void){ +	sqlite3_close(database); +} + + +i64 db_create_user(const char *name,const char *pass){ +	sqlite3_stmt *stmt; +	SQLITE(prepare_v2,database,"insert into Users (username, pass) values (?, ?)",-1,&stmt,NULL); +	SQLITE(bind_text,stmt,1,name,-1,SQLITE_STATIC); +	SQLITE(bind_text,stmt,2,pass,-1,SQLITE_STATIC); +	if(sqlite3_step(stmt)!=SQLITE_DONE)die_sqlite("sqlite3_step"); +	SQLITE(finalize,stmt); + +	i64 userid=sqlite3_last_insert_rowid(database); + +	SQLITE(prepare_v2,database,"insert into UserNames (name, user) values (?, ?)",-1,&stmt,NULL); +	SQLITE(bind_text,stmt,1,name,-1,SQLITE_STATIC); +	SQLITE(bind_int64,stmt,2,userid); +	if(sqlite3_step(stmt)!=SQLITE_DONE)die_sqlite("sqlite3_step"); +	SQLITE(finalize,stmt); + +	return userid; +} + +i64 db_find_user(const char *name){ +	sqlite3_stmt *stmt; +	SQLITE(prepare_v2,database,"select user from UserNames where name = ?",-1,&stmt,NULL); +	SQLITE(bind_text,stmt,1,name,-1,SQLITE_STATIC); +	i64 userid=-1; +	if(sqlite3_step(stmt)==SQLITE_ROW){ +		userid=sqlite3_column_int64(stmt,1); +	} +	SQLITE(finalize,stmt); +	return userid; +} @@ -0,0 +1,39 @@ +#pragma once + +#include "global.h" + + +struct db_name_id{ +	char *name; +	i64 id; +}; + +struct db_message{ +	i64 roomid,userid,timestamp; +	char *message; +}; + +struct db_message_list{ +	i64 count; +	struct db_message *list; +}; + +void db_init(void); +void db_close(void); + +struct db_name_id db_create_room(void); +bool db_delete_room(i64 roomid); +bool db_add_member(i64 roomid,i64 userid); +bool db_remove_member(i64 roomid,u64 userid); +i64 db_find_room(const char *name);  // -1 if not found + +i64 db_create_user(const char *name,const char *pass); +bool db_set_username(i64 userid,const char *name); +bool db_set_pass(i64 userid,const char *pass); +char* db_get_username(i64 userid); +char* db_get_pass(i64 userid); +bool db_delete_user(i64 userid); +i64 db_find_user(const char *name);  // -1 if not found + +bool db_create_message(i64 roomid,i64 userid,i64 timestamp,const char *message); +struct db_message_list db_get_messages(i64 roomid,i64 timestamp,i64 count);  // pass timestamp==-1 for last messages diff --git a/global.c b/global.c new file mode 100644 index 0000000..bf5bf2e --- /dev/null +++ b/global.c @@ -0,0 +1,33 @@ +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <stdarg.h> +#include <errno.h> +#include "global.h" + +__attribute__((noreturn, format(printf, 1, 2))) +void die(const char *format,...){ +	fprintf(stderr,"DIE: "); +	va_list ap; +	va_start(ap,format); +	vfprintf(stderr,format,ap); +	va_end(ap); +	fputc('\n',stderr); +	exit(1); +} + +__attribute__((noreturn)) +void die_perror(const char *func){ +	fprintf(stderr,"DIE: %s: %s\n",func,strerror(errno)); +	exit(1); +} + +__attribute__((format (printf, 1, 2))) +void debug(const char *format,...){ +	fprintf(stderr,"DEBUG: "); +	va_list ap; +	va_start(ap,format); +	vfprintf(stderr,format,ap); +	va_end(ap); +	fputc('\n',stderr); +} diff --git a/global.h b/global.h new file mode 100644 index 0000000..de0873b --- /dev/null +++ b/global.h @@ -0,0 +1,14 @@ +#pragma once + +#include <stdbool.h> +#include <stdint.h> +#include <inttypes.h> +#include "memory.h" + +typedef int64_t i64; +typedef uint64_t u64; + +void die(const char *format,...) __attribute__((noreturn, format(printf, 1, 2))); +void die_perror(const char *func) __attribute__((noreturn)); + +void debug(const char *format,...) __attribute__((format(printf, 1, 2))); @@ -0,0 +1,123 @@ +#include <stdio.h> +#include <string.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <unistd.h> +#include <errno.h> +#include <assert.h> +#include "command.h" +#include "conn_data.h" +#include "db.h" +#include "runloop.h" + +#define PORT (29536)  // python: int("msg",36) + + +static int create_server_socket(void){ +	int sock=socket(AF_INET,SOCK_STREAM,0); +	if(sock<0)die_perror("socket"); +	int one=1; +	setsockopt(sock,SOL_SOCKET,SO_REUSEADDR,&one,sizeof one); + +	struct sockaddr_in name; +	name.sin_family=AF_INET; +	name.sin_addr.s_addr=htonl(INADDR_ANY); +	name.sin_port=htons(PORT); +	if(bind(sock,(struct sockaddr*)&name,sizeof name)<0)die_perror("bind"); + +	if(listen(sock,16)<0)die_perror("listen"); +	return sock; +} + +struct hash_item{ +	struct conn_data cd; +	struct hash_item *next; +}; + +#define CONN_HASH_SIZE (16) +static struct hash_item *conn_hash[CONN_HASH_SIZE]; + +static struct conn_data* find_conn_data(int fd){ +	struct hash_item *item=conn_hash[fd%CONN_HASH_SIZE]; +	while(item&&item->cd.fd!=fd)item=item->next; +	assert(item); +	return &item->cd; +} + +static void delete_conn_data(int fd){ +	struct hash_item *item=conn_hash[fd%CONN_HASH_SIZE]; +	assert(item); +	if(item->cd.fd==fd){ +		conn_hash[fd%CONN_HASH_SIZE]=item->next; +		conn_data_nullify(&item->cd); +		free(item); +		return; +	} +	struct hash_item *parent=NULL; +	while(item&&item->cd.fd!=fd){ +		parent=item; +		item=item->next; +	} +	assert(parent); +	assert(item); +	conn_data_nullify(&item->cd); +	parent->next=item->next; +	free(item); +} + +static bool client_socket_callback(int fd){ +	struct conn_data *data=find_conn_data(fd); +	if(data->bufsz-data->buflen<256){ +		data->bufsz*=2; +		data->buffer=realloc(data->buffer,data->bufsz,char); +	} + +	ssize_t ret; +	do ret=read(fd,data->buffer+data->buflen,data->bufsz-data->buflen); +	while(ret<0&&errno==EINTR); +	if(ret<0)die_perror("read"); +	if(ret==0){ +		delete_conn_data(fd); +		close(fd); +		return true; +	} +	data->buflen+=ret; + +	char *lfp=memchr(data->buffer,'\n',data->buflen); +	if(lfp==NULL)return false; +	size_t length=lfp-data->buffer; +	bool should_close=handle_input_line(fd,data->buffer,length); +	memmove(data->buffer,lfp+1,data->buflen-length-1); +	data->buflen-=length+1; + +	if(should_close){ +		delete_conn_data(fd); +		close(fd); +		return true; +	} +	return false; +} + +static bool server_socket_callback(int fd){ +	int sock; +	do sock=accept(fd,NULL,NULL); +	while(sock<0&&errno==EINTR); +	if(sock<0)die_perror("accept"); +	runloop_add_fd(sock,client_socket_callback); + +	struct hash_item *item=malloc(1,struct hash_item); +	conn_data_init(&item->cd,sock); +	item->next=conn_hash[sock%CONN_HASH_SIZE]; +	conn_hash[sock%CONN_HASH_SIZE]=item; +	debug("Added conn_data for fd=%d",sock); +	return false; +} + +int main(void){ +	db_init(); +	int sock=create_server_socket(); +	printf("Listening on port %d\n",PORT); +	runloop_add_fd(sock,server_socket_callback); +	runloop_run(); +	db_close(); +} diff --git a/memory.c b/memory.c new file mode 100644 index 0000000..952bedf --- /dev/null +++ b/memory.c @@ -0,0 +1,9 @@ +#include "global.h" +#include "memory.h" + +void* check_after_allocation(const char *func,size_t num,size_t sz,void *ptr){ +	if(ptr==NULL){ +		die("Allocation failed: %s(%zu * %zuB = %zu)",func,num,sz,num*sz); +	} +	return ptr; +} diff --git a/memory.h b/memory.h new file mode 100644 index 0000000..970648f --- /dev/null +++ b/memory.h @@ -0,0 +1,12 @@ +#pragma once + +#include <stdlib.h> + +#define malloc(num,type) \ +	((type*)check_after_allocation("malloc",num,sizeof(type),malloc((num)*sizeof(type)))) +#define calloc(num,type) \ +	((type*)check_after_allocation("calloc",num,sizeof(type),calloc((num),sizeof(type)))) +#define realloc(ptr,num,type) \ +	((type*)check_after_allocation("realloc",num,sizeof(type),realloc((ptr),(num)*sizeof(type)))) + +void* check_after_allocation(const char *func,size_t num,size_t sz,void *ptr); diff --git a/runloop.c b/runloop.c new file mode 100644 index 0000000..d2865a2 --- /dev/null +++ b/runloop.c @@ -0,0 +1,53 @@ +#include <string.h> +#include <sys/select.h> +#include "runloop.h" + + +struct fd_list_item{ +	int fd; +	runloop_callback *func; +}; + +static struct fd_list_item *fd_list; +static size_t fd_list_len,fd_list_cap; + +__attribute__((constructor)) +static void constructor(void){ +	fd_list_cap=16; +	fd_list_len=0; +	fd_list=malloc(fd_list_cap,struct fd_list_item); +} + + +void runloop_add_fd(int fd,runloop_callback *func){ +	if(fd_list_len==fd_list_cap){ +		fd_list_cap*=2; +		fd_list=realloc(fd_list,fd_list_cap,struct fd_list_item); +	} +	fd_list[fd_list_len].fd=fd; +	fd_list[fd_list_len].func=func; +	fd_list_len++; +} + +void runloop_run(void){ +	while(true){ +		fd_set inset; +		FD_ZERO(&inset); +		int maxfd=-1; +		for(size_t i=0;i<fd_list_len;i++){ +			FD_SET(fd_list[i].fd,&inset); +			if(fd_list[i].fd>maxfd)maxfd=fd_list[i].fd; +		} +		int ret=select(maxfd+1,&inset,NULL,NULL,NULL); +		if(ret<=0)die_perror("select"); +		for(size_t i=0;i<fd_list_len;i++){ +			if(FD_ISSET(fd_list[i].fd,&inset)){ +				if(fd_list[i].func(fd_list[i].fd)){ +					memmove(fd_list+i,fd_list+i+1,fd_list_len-i-1); +					i--; +					fd_list_len--; +				} +			} +		} +	} +} diff --git a/runloop.h b/runloop.h new file mode 100644 index 0000000..fc1f7e8 --- /dev/null +++ b/runloop.h @@ -0,0 +1,10 @@ +#pragma once + +#include "global.h" + + +// Return true to remove fd from runloop +typedef bool runloop_callback(int fd); + +void runloop_add_fd(int fd,runloop_callback *func); +void runloop_run(void);  // Returns when empty diff --git a/schema.sql b/schema.sql new file mode 100644 index 0000000..5a888d6 --- /dev/null +++ b/schema.sql @@ -0,0 +1,43 @@ +pragma foreign_keys = on; + +create table Rooms ( +	id integer primary key, +	name text +); + +create table RoomNames ( +	name text primary key, +	room integer not null, +	foreign key(room) references Rooms(id) on delete cascade +); + +create table Members ( +	room integer, +	user integer, +	primary key(room, user), +	foreign key(room) references Rooms(id) on delete cascade, +	foreign key(user) references Users(id) on delete cascade +); + +create table Users ( +	id integer primary key, +	username text, +	pass text +); + +create table UserNames ( +	name text primary key, +	user integer, +	foreign key(user) references Users(id) on delete cascade +); + +create table Messages ( +	id integer primary key, +	room integer not null, +	user integer null, +	time integer not null, +	message text, +	foreign key(room) references Rooms(id) on delete cascade, +	foreign key(user) references Users(id) on delete set null +); +create index messages_time_index on Messages(time); | 
