summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/src/constants.rs6
-rw-r--r--server/src/db.rs116
-rw-r--r--server/src/main.rs151
-rw-r--r--server/src/util.rs30
4 files changed, 159 insertions, 144 deletions
diff --git a/server/src/constants.rs b/server/src/constants.rs
new file mode 100644
index 0000000..0f603da
--- /dev/null
+++ b/server/src/constants.rs
@@ -0,0 +1,6 @@
+pub const DB_FILE_NAME: &str = "notes.db";
+pub const DB_SCHEMA_VERSION: i64 = 1;
+
+pub const LOGIN_TOKEN_EXPIRE_SECS: i64 = 40 * 24 * 3600;
+/// Upon a request less than this before the expiration time, the token is refreshed.
+pub const LOGIN_TOKEN_REFRESH_MARGIN: i64 = 20 * 24 * 3600;
diff --git a/server/src/db.rs b/server/src/db.rs
new file mode 100644
index 0000000..2211c42
--- /dev/null
+++ b/server/src/db.rs
@@ -0,0 +1,116 @@
+use tokio;
+use tokio::sync::Mutex;
+use std::{
+ sync::Arc,
+ ops::DerefMut,
+ time::SystemTime,
+ fs,
+ process,
+};
+use sqlx::{
+ ConnectOptions,
+ SqliteConnection,
+ Row,
+};
+use sqlx::sqlite::{
+ SqliteConnectOptions,
+ SqliteJournalMode,
+};
+use rand::RngCore;
+use argon2::password_hash::rand_core::OsRng;
+
+use crate::constants::*;
+use crate::util::base64_encode;
+
+async fn initialise_schema(conn: &mut SqliteConnection) {
+ match sqlx::query(&fs::read_to_string("schema.sql").unwrap()).execute(conn).await {
+ Ok(_) => {},
+ Err(err) => {
+ eprintln!("Error: Could not initialise database schema: {err}");
+ process::exit(1);
+ },
+ }
+}
+
+pub async fn open() -> SqliteConnection {
+ let opts = SqliteConnectOptions::new()
+ .filename(DB_FILE_NAME)
+ .create_if_missing(true)
+ .journal_mode(SqliteJournalMode::Wal)
+ .pragma("foreign_keys", "on");
+
+ let mut conn = match opts.connect().await {
+ Ok(conn) => conn,
+ Err(err) => {
+ eprintln!("Error: Could not open database file '{DB_FILE_NAME}': {err}");
+ process::exit(1);
+ }
+ };
+
+ match sqlx::query("select version from meta").fetch_one(&mut conn).await {
+ Ok(row) => {
+ let ver: i64 = row.get(0);
+ if ver == DB_SCHEMA_VERSION {
+ conn
+ } else {
+ eprintln!("Error: Database schema version {ver} but application schema version {DB_SCHEMA_VERSION}");
+ process::exit(1);
+ }
+ },
+ Err(err) => {
+ eprintln!("Error: Failed querying database meta version: {err}");
+ eprintln!("Initialising schema.");
+ initialise_schema(&mut conn).await;
+ conn
+ },
+ }
+}
+
+pub type DB = Arc<Mutex<SqliteConnection>>;
+
+pub async fn register_account(db: DB, username: &str, passhash: &str) -> Result<(), String> {
+ let mut conn = db.lock().await;
+ match sqlx::query("insert into Users (username, passhash) values ($1, $2)")
+ .bind(username)
+ .bind(passhash)
+ .execute(conn.deref_mut())
+ .await {
+ Ok(_) => Ok(()),
+ Err(_) => Err("User already exists".to_string()),
+ }
+}
+
+pub async fn get_passhash(db: DB, username: &str) -> Result<String, ()> {
+ let mut conn = db.lock().await;
+ match sqlx::query("select passhash from Users where username = $1")
+ .bind(username)
+ .fetch_optional(conn.deref_mut())
+ .await {
+ Ok(Some(row)) => Ok(row.get(0)),
+ Ok(None) => Err(()),
+ Err(err) => {
+ eprintln!("db_get_passhash: sqlx error: {err}");
+ Err(())
+ }
+ }
+}
+
+fn generate_login_token() -> String {
+ let mut bytes = [0u8; 32];
+ OsRng.fill_bytes(&mut bytes);
+ base64_encode(&bytes)
+}
+
+pub async fn create_login_token(db: DB, username: &str) -> Result<String, ()> {
+ let mut conn = db.lock().await;
+ let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
+ let token = generate_login_token();
+ match sqlx::query("insert into Logins (user, token, expire) values ($1, $2, $3)")
+ .bind(username)
+ .bind(&token)
+ .bind(now as i64 + LOGIN_TOKEN_EXPIRE_SECS)
+ .execute(conn.deref_mut()).await {
+ Ok(_) => Ok(token),
+ Err(_) => Err(()),
+ }
+}
diff --git a/server/src/main.rs b/server/src/main.rs
index a099a13..11fddb4 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -2,19 +2,6 @@ use tokio;
use tokio::sync::Mutex;
use std::{
sync::Arc,
- ops::DerefMut,
- time::SystemTime,
- fs,
- process,
-};
-use sqlx::{
- ConnectOptions,
- SqliteConnection,
- Row,
-};
-use sqlx::sqlite::{
- SqliteConnectOptions,
- SqliteJournalMode,
};
use warp::{Filter, Reply, Rejection};
use argon2::{
@@ -25,134 +12,10 @@ use argon2::{
Argon2
};
use serde::Deserialize;
-use rand::RngCore;
-
-const DB_FILE_NAME: &str = "notes.db";
-const DB_SCHEMA_VERSION: i64 = 1;
-
-fn base64_encode(bytes: &[u8]) -> String {
- const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
-
- let len = (bytes.len() + 2) / 3 * 4;
- let mut dest = Vec::with_capacity(len);
- dest.resize(len, 0u8);
- for i in 0 .. bytes.len() / 3 {
- dest[4 * i + 0] = ALPHABET[(bytes[3 * i + 0] & 0x3f) as usize];
- dest[4 * i + 1] = ALPHABET[(((bytes[3 * i + 0] & 0xc0) >> 6) | ((bytes[3 * i + 1] & 0x0f) << 2)) as usize];
- dest[4 * i + 2] = ALPHABET[(((bytes[3 * i + 1] & 0xf0) >> 4) | ((bytes[3 * i + 2] & 0x03) << 4)) as usize];
- dest[4 * i + 3] = ALPHABET[((bytes[3 * i + 1] & 0xfc) >> 2) as usize];
- }
- let last = bytes.len() / 3;
- match bytes.len() % 3 {
- 1 => {
- dest[4 * last + 0] = ALPHABET[(bytes[3 * last + 0] & 0x3f) as usize];
- dest[4 * last + 1] = ALPHABET[((bytes[3 * last + 0] & 0xc0) >> 6) as usize];
- dest[4 * last + 2] = b'=';
- dest[4 * last + 3] = b'=';
- }
- 2 => {
- dest[4 * last + 0] = ALPHABET[(bytes[3 * last + 0] & 0x3f) as usize];
- dest[4 * last + 1] = ALPHABET[(((bytes[3 * last + 0] & 0xc0) >> 6) | ((bytes[3 * last + 1] & 0x0f) << 2)) as usize];
- dest[4 * last + 2] = ALPHABET[((bytes[3 * last + 1] & 0xf0) >> 4) as usize];
- dest[4 * last + 3] = b'=';
- }
- _0 => {}
- }
- String::from_utf8(dest).unwrap()
-}
-
-async fn db_initialise_schema(conn: &mut SqliteConnection) {
- match sqlx::query(&fs::read_to_string("schema.sql").unwrap()).execute(conn).await {
- Ok(_) => {},
- Err(err) => {
- eprintln!("Error: Could not initialise database schema: {err}");
- process::exit(1);
- },
- }
-}
-
-async fn open_db() -> SqliteConnection {
- let opts = SqliteConnectOptions::new()
- .filename(DB_FILE_NAME)
- .create_if_missing(true)
- .journal_mode(SqliteJournalMode::Wal)
- .pragma("foreign_keys", "on");
-
- let mut conn = match opts.connect().await {
- Ok(conn) => conn,
- Err(err) => {
- eprintln!("Error: Could not open database file '{DB_FILE_NAME}': {err}");
- process::exit(1);
- }
- };
-
- match sqlx::query("select version from meta").fetch_one(&mut conn).await {
- Ok(row) => {
- let ver: i64 = row.get(0);
- if ver == DB_SCHEMA_VERSION {
- conn
- } else {
- eprintln!("Error: Database schema version {ver} but application schema version {DB_SCHEMA_VERSION}");
- process::exit(1);
- }
- },
- Err(err) => {
- eprintln!("Error: Failed querying database meta version: {err}");
- eprintln!("Initialising schema.");
- db_initialise_schema(&mut conn).await;
- conn
- },
- }
-}
-
-type DB = Arc<Mutex<SqliteConnection>>;
-
-async fn db_register_account(db: DB, username: &str, passhash: &str) -> Result<(), String> {
- let mut conn = db.lock().await;
- match sqlx::query("insert into Users (username, passhash) values ($1, $2)")
- .bind(username)
- .bind(passhash)
- .execute(conn.deref_mut())
- .await {
- Ok(_) => Ok(()),
- Err(_) => Err("User already exists".to_string()),
- }
-}
-async fn db_get_passhash(db: DB, username: &str) -> Result<String, ()> {
- let mut conn = db.lock().await;
- match sqlx::query("select passhash from Users where username = $1")
- .bind(username)
- .fetch_optional(conn.deref_mut())
- .await {
- Ok(Some(row)) => Ok(row.get(0)),
- Ok(None) => Err(()),
- Err(err) => {
- eprintln!("db_get_passhash: sqlx error: {err}");
- Err(())
- }
- }
-}
-
-fn generate_login_token() -> String {
- let mut bytes = [0u8; 32];
- OsRng.fill_bytes(&mut bytes);
- base64_encode(&bytes)
-}
-
-async fn db_create_login_token(db: DB, username: &str) -> Result<String, ()> {
- let mut conn = db.lock().await;
- let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
- let token = generate_login_token();
- match sqlx::query("insert into Logins (user, token, expire) values ($1, $2, $3)")
- .bind(username)
- .bind(&token)
- .bind(now as i64 + 40 * 24 * 3600) // valid for 40 days
- .execute(conn.deref_mut()).await {
- Ok(_) => Ok(token),
- Err(_) => Err(()),
- }
-}
+mod constants; use constants::*;
+mod util;
+mod db; use db::DB;
async fn handle_ping() -> Result<String, Rejection> {
return Ok("pong".to_string())
@@ -184,14 +47,14 @@ async fn handle_register(db: DB, req: RegisterReq) -> Response {
Ok(hash) => hash.to_string(),
Err(_) => return mk_bad_request!("bad request"),
};
- match db_register_account(db, &req.username, &hash.to_string()).await {
+ match db::register_account(db, &req.username, &hash.to_string()).await {
Ok(()) => Ok(Box::new("Registered")),
Err(err) => mk_bad_request!(err),
}
}
async fn handle_login(db: DB, req: RegisterReq) -> Result<Box<dyn Reply>, Rejection> {
- let passhash = match db_get_passhash(db.clone(), &req.username).await {
+ let passhash = match db::get_passhash(db.clone(), &req.username).await {
Ok(passhash) => passhash,
Err(()) => return mk_not_found!("User not found"),
};
@@ -205,7 +68,7 @@ async fn handle_login(db: DB, req: RegisterReq) -> Result<Box<dyn Reply>, Reject
if let Err(_) = Argon2::default().verify_password(req.password.as_bytes(), &parsed_hash) {
return Ok(Box::new(warp::reply::with_status("Incorrect password", warp::http::StatusCode::UNAUTHORIZED)));
}
- match db_create_login_token(db, &req.username).await {
+ match db::create_login_token(db, &req.username).await {
Ok(token) => Ok(Box::new(token)),
Err(()) => {
eprintln!("Failed inserting login token for user '{0}'", &req.username);
@@ -220,7 +83,7 @@ macro_rules! db_handler1 {
#[tokio::main]
async fn main() {
- let db: DB = Arc::new(Mutex::new(open_db().await));
+ let db: DB = Arc::new(Mutex::new(db::open().await));
println!("Opened database at {DB_FILE_NAME}.");
let router =
diff --git a/server/src/util.rs b/server/src/util.rs
new file mode 100644
index 0000000..6a0d8d9
--- /dev/null
+++ b/server/src/util.rs
@@ -0,0 +1,30 @@
+pub fn base64_encode(bytes: &[u8]) -> String {
+ const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+
+ let len = (bytes.len() + 2) / 3 * 4;
+ let mut dest = Vec::with_capacity(len);
+ dest.resize(len, 0u8);
+ for i in 0 .. bytes.len() / 3 {
+ dest[4 * i + 0] = ALPHABET[(bytes[3 * i + 0] & 0x3f) as usize];
+ dest[4 * i + 1] = ALPHABET[(((bytes[3 * i + 0] & 0xc0) >> 6) | ((bytes[3 * i + 1] & 0x0f) << 2)) as usize];
+ dest[4 * i + 2] = ALPHABET[(((bytes[3 * i + 1] & 0xf0) >> 4) | ((bytes[3 * i + 2] & 0x03) << 4)) as usize];
+ dest[4 * i + 3] = ALPHABET[((bytes[3 * i + 1] & 0xfc) >> 2) as usize];
+ }
+ let last = bytes.len() / 3;
+ match bytes.len() % 3 {
+ 1 => {
+ dest[4 * last + 0] = ALPHABET[(bytes[3 * last + 0] & 0x3f) as usize];
+ dest[4 * last + 1] = ALPHABET[((bytes[3 * last + 0] & 0xc0) >> 6) as usize];
+ dest[4 * last + 2] = b'=';
+ dest[4 * last + 3] = b'=';
+ }
+ 2 => {
+ dest[4 * last + 0] = ALPHABET[(bytes[3 * last + 0] & 0x3f) as usize];
+ dest[4 * last + 1] = ALPHABET[(((bytes[3 * last + 0] & 0xc0) >> 6) | ((bytes[3 * last + 1] & 0x0f) << 2)) as usize];
+ dest[4 * last + 2] = ALPHABET[((bytes[3 * last + 1] & 0xf0) >> 4) as usize];
+ dest[4 * last + 3] = b'=';
+ }
+ _0 => {}
+ }
+ String::from_utf8(dest).unwrap()
+}