diff options
Diffstat (limited to 'server/src')
-rw-r--r-- | server/src/constants.rs | 6 | ||||
-rw-r--r-- | server/src/db.rs | 116 | ||||
-rw-r--r-- | server/src/main.rs | 151 | ||||
-rw-r--r-- | server/src/util.rs | 30 |
4 files changed, 159 insertions, 144 deletions
diff --git a/server/src/constants.rs b/server/src/constants.rs new file mode 100644 index 0000000..0f603da --- /dev/null +++ b/server/src/constants.rs @@ -0,0 +1,6 @@ +pub const DB_FILE_NAME: &str = "notes.db"; +pub const DB_SCHEMA_VERSION: i64 = 1; + +pub const LOGIN_TOKEN_EXPIRE_SECS: i64 = 40 * 24 * 3600; +/// Upon a request less than this before the expiration time, the token is refreshed. +pub const LOGIN_TOKEN_REFRESH_MARGIN: i64 = 20 * 24 * 3600; diff --git a/server/src/db.rs b/server/src/db.rs new file mode 100644 index 0000000..2211c42 --- /dev/null +++ b/server/src/db.rs @@ -0,0 +1,116 @@ +use tokio; +use tokio::sync::Mutex; +use std::{ + sync::Arc, + ops::DerefMut, + time::SystemTime, + fs, + process, +}; +use sqlx::{ + ConnectOptions, + SqliteConnection, + Row, +}; +use sqlx::sqlite::{ + SqliteConnectOptions, + SqliteJournalMode, +}; +use rand::RngCore; +use argon2::password_hash::rand_core::OsRng; + +use crate::constants::*; +use crate::util::base64_encode; + +async fn initialise_schema(conn: &mut SqliteConnection) { + match sqlx::query(&fs::read_to_string("schema.sql").unwrap()).execute(conn).await { + Ok(_) => {}, + Err(err) => { + eprintln!("Error: Could not initialise database schema: {err}"); + process::exit(1); + }, + } +} + +pub async fn open() -> SqliteConnection { + let opts = SqliteConnectOptions::new() + .filename(DB_FILE_NAME) + .create_if_missing(true) + .journal_mode(SqliteJournalMode::Wal) + .pragma("foreign_keys", "on"); + + let mut conn = match opts.connect().await { + Ok(conn) => conn, + Err(err) => { + eprintln!("Error: Could not open database file '{DB_FILE_NAME}': {err}"); + process::exit(1); + } + }; + + match sqlx::query("select version from meta").fetch_one(&mut conn).await { + Ok(row) => { + let ver: i64 = row.get(0); + if ver == DB_SCHEMA_VERSION { + conn + } else { + eprintln!("Error: Database schema version {ver} but application schema version {DB_SCHEMA_VERSION}"); + process::exit(1); + } + }, + Err(err) => { + eprintln!("Error: Failed querying database meta version: {err}"); + eprintln!("Initialising schema."); + initialise_schema(&mut conn).await; + conn + }, + } +} + +pub type DB = Arc<Mutex<SqliteConnection>>; + +pub async fn register_account(db: DB, username: &str, passhash: &str) -> Result<(), String> { + let mut conn = db.lock().await; + match sqlx::query("insert into Users (username, passhash) values ($1, $2)") + .bind(username) + .bind(passhash) + .execute(conn.deref_mut()) + .await { + Ok(_) => Ok(()), + Err(_) => Err("User already exists".to_string()), + } +} + +pub async fn get_passhash(db: DB, username: &str) -> Result<String, ()> { + let mut conn = db.lock().await; + match sqlx::query("select passhash from Users where username = $1") + .bind(username) + .fetch_optional(conn.deref_mut()) + .await { + Ok(Some(row)) => Ok(row.get(0)), + Ok(None) => Err(()), + Err(err) => { + eprintln!("db_get_passhash: sqlx error: {err}"); + Err(()) + } + } +} + +fn generate_login_token() -> String { + let mut bytes = [0u8; 32]; + OsRng.fill_bytes(&mut bytes); + base64_encode(&bytes) +} + +pub async fn create_login_token(db: DB, username: &str) -> Result<String, ()> { + let mut conn = db.lock().await; + let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs(); + let token = generate_login_token(); + match sqlx::query("insert into Logins (user, token, expire) values ($1, $2, $3)") + .bind(username) + .bind(&token) + .bind(now as i64 + LOGIN_TOKEN_EXPIRE_SECS) + .execute(conn.deref_mut()).await { + Ok(_) => Ok(token), + Err(_) => Err(()), + } +} diff --git a/server/src/main.rs b/server/src/main.rs index a099a13..11fddb4 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -2,19 +2,6 @@ use tokio; use tokio::sync::Mutex; use std::{ sync::Arc, - ops::DerefMut, - time::SystemTime, - fs, - process, -}; -use sqlx::{ - ConnectOptions, - SqliteConnection, - Row, -}; -use sqlx::sqlite::{ - SqliteConnectOptions, - SqliteJournalMode, }; use warp::{Filter, Reply, Rejection}; use argon2::{ @@ -25,134 +12,10 @@ use argon2::{ Argon2 }; use serde::Deserialize; -use rand::RngCore; - -const DB_FILE_NAME: &str = "notes.db"; -const DB_SCHEMA_VERSION: i64 = 1; - -fn base64_encode(bytes: &[u8]) -> String { - const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - - let len = (bytes.len() + 2) / 3 * 4; - let mut dest = Vec::with_capacity(len); - dest.resize(len, 0u8); - for i in 0 .. bytes.len() / 3 { - dest[4 * i + 0] = ALPHABET[(bytes[3 * i + 0] & 0x3f) as usize]; - dest[4 * i + 1] = ALPHABET[(((bytes[3 * i + 0] & 0xc0) >> 6) | ((bytes[3 * i + 1] & 0x0f) << 2)) as usize]; - dest[4 * i + 2] = ALPHABET[(((bytes[3 * i + 1] & 0xf0) >> 4) | ((bytes[3 * i + 2] & 0x03) << 4)) as usize]; - dest[4 * i + 3] = ALPHABET[((bytes[3 * i + 1] & 0xfc) >> 2) as usize]; - } - let last = bytes.len() / 3; - match bytes.len() % 3 { - 1 => { - dest[4 * last + 0] = ALPHABET[(bytes[3 * last + 0] & 0x3f) as usize]; - dest[4 * last + 1] = ALPHABET[((bytes[3 * last + 0] & 0xc0) >> 6) as usize]; - dest[4 * last + 2] = b'='; - dest[4 * last + 3] = b'='; - } - 2 => { - dest[4 * last + 0] = ALPHABET[(bytes[3 * last + 0] & 0x3f) as usize]; - dest[4 * last + 1] = ALPHABET[(((bytes[3 * last + 0] & 0xc0) >> 6) | ((bytes[3 * last + 1] & 0x0f) << 2)) as usize]; - dest[4 * last + 2] = ALPHABET[((bytes[3 * last + 1] & 0xf0) >> 4) as usize]; - dest[4 * last + 3] = b'='; - } - _0 => {} - } - String::from_utf8(dest).unwrap() -} - -async fn db_initialise_schema(conn: &mut SqliteConnection) { - match sqlx::query(&fs::read_to_string("schema.sql").unwrap()).execute(conn).await { - Ok(_) => {}, - Err(err) => { - eprintln!("Error: Could not initialise database schema: {err}"); - process::exit(1); - }, - } -} - -async fn open_db() -> SqliteConnection { - let opts = SqliteConnectOptions::new() - .filename(DB_FILE_NAME) - .create_if_missing(true) - .journal_mode(SqliteJournalMode::Wal) - .pragma("foreign_keys", "on"); - - let mut conn = match opts.connect().await { - Ok(conn) => conn, - Err(err) => { - eprintln!("Error: Could not open database file '{DB_FILE_NAME}': {err}"); - process::exit(1); - } - }; - - match sqlx::query("select version from meta").fetch_one(&mut conn).await { - Ok(row) => { - let ver: i64 = row.get(0); - if ver == DB_SCHEMA_VERSION { - conn - } else { - eprintln!("Error: Database schema version {ver} but application schema version {DB_SCHEMA_VERSION}"); - process::exit(1); - } - }, - Err(err) => { - eprintln!("Error: Failed querying database meta version: {err}"); - eprintln!("Initialising schema."); - db_initialise_schema(&mut conn).await; - conn - }, - } -} - -type DB = Arc<Mutex<SqliteConnection>>; - -async fn db_register_account(db: DB, username: &str, passhash: &str) -> Result<(), String> { - let mut conn = db.lock().await; - match sqlx::query("insert into Users (username, passhash) values ($1, $2)") - .bind(username) - .bind(passhash) - .execute(conn.deref_mut()) - .await { - Ok(_) => Ok(()), - Err(_) => Err("User already exists".to_string()), - } -} -async fn db_get_passhash(db: DB, username: &str) -> Result<String, ()> { - let mut conn = db.lock().await; - match sqlx::query("select passhash from Users where username = $1") - .bind(username) - .fetch_optional(conn.deref_mut()) - .await { - Ok(Some(row)) => Ok(row.get(0)), - Ok(None) => Err(()), - Err(err) => { - eprintln!("db_get_passhash: sqlx error: {err}"); - Err(()) - } - } -} - -fn generate_login_token() -> String { - let mut bytes = [0u8; 32]; - OsRng.fill_bytes(&mut bytes); - base64_encode(&bytes) -} - -async fn db_create_login_token(db: DB, username: &str) -> Result<String, ()> { - let mut conn = db.lock().await; - let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs(); - let token = generate_login_token(); - match sqlx::query("insert into Logins (user, token, expire) values ($1, $2, $3)") - .bind(username) - .bind(&token) - .bind(now as i64 + 40 * 24 * 3600) // valid for 40 days - .execute(conn.deref_mut()).await { - Ok(_) => Ok(token), - Err(_) => Err(()), - } -} +mod constants; use constants::*; +mod util; +mod db; use db::DB; async fn handle_ping() -> Result<String, Rejection> { return Ok("pong".to_string()) @@ -184,14 +47,14 @@ async fn handle_register(db: DB, req: RegisterReq) -> Response { Ok(hash) => hash.to_string(), Err(_) => return mk_bad_request!("bad request"), }; - match db_register_account(db, &req.username, &hash.to_string()).await { + match db::register_account(db, &req.username, &hash.to_string()).await { Ok(()) => Ok(Box::new("Registered")), Err(err) => mk_bad_request!(err), } } async fn handle_login(db: DB, req: RegisterReq) -> Result<Box<dyn Reply>, Rejection> { - let passhash = match db_get_passhash(db.clone(), &req.username).await { + let passhash = match db::get_passhash(db.clone(), &req.username).await { Ok(passhash) => passhash, Err(()) => return mk_not_found!("User not found"), }; @@ -205,7 +68,7 @@ async fn handle_login(db: DB, req: RegisterReq) -> Result<Box<dyn Reply>, Reject if let Err(_) = Argon2::default().verify_password(req.password.as_bytes(), &parsed_hash) { return Ok(Box::new(warp::reply::with_status("Incorrect password", warp::http::StatusCode::UNAUTHORIZED))); } - match db_create_login_token(db, &req.username).await { + match db::create_login_token(db, &req.username).await { Ok(token) => Ok(Box::new(token)), Err(()) => { eprintln!("Failed inserting login token for user '{0}'", &req.username); @@ -220,7 +83,7 @@ macro_rules! db_handler1 { #[tokio::main] async fn main() { - let db: DB = Arc::new(Mutex::new(open_db().await)); + let db: DB = Arc::new(Mutex::new(db::open().await)); println!("Opened database at {DB_FILE_NAME}."); let router = diff --git a/server/src/util.rs b/server/src/util.rs new file mode 100644 index 0000000..6a0d8d9 --- /dev/null +++ b/server/src/util.rs @@ -0,0 +1,30 @@ +pub fn base64_encode(bytes: &[u8]) -> String { + const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + let len = (bytes.len() + 2) / 3 * 4; + let mut dest = Vec::with_capacity(len); + dest.resize(len, 0u8); + for i in 0 .. bytes.len() / 3 { + dest[4 * i + 0] = ALPHABET[(bytes[3 * i + 0] & 0x3f) as usize]; + dest[4 * i + 1] = ALPHABET[(((bytes[3 * i + 0] & 0xc0) >> 6) | ((bytes[3 * i + 1] & 0x0f) << 2)) as usize]; + dest[4 * i + 2] = ALPHABET[(((bytes[3 * i + 1] & 0xf0) >> 4) | ((bytes[3 * i + 2] & 0x03) << 4)) as usize]; + dest[4 * i + 3] = ALPHABET[((bytes[3 * i + 1] & 0xfc) >> 2) as usize]; + } + let last = bytes.len() / 3; + match bytes.len() % 3 { + 1 => { + dest[4 * last + 0] = ALPHABET[(bytes[3 * last + 0] & 0x3f) as usize]; + dest[4 * last + 1] = ALPHABET[((bytes[3 * last + 0] & 0xc0) >> 6) as usize]; + dest[4 * last + 2] = b'='; + dest[4 * last + 3] = b'='; + } + 2 => { + dest[4 * last + 0] = ALPHABET[(bytes[3 * last + 0] & 0x3f) as usize]; + dest[4 * last + 1] = ALPHABET[(((bytes[3 * last + 0] & 0xc0) >> 6) | ((bytes[3 * last + 1] & 0x0f) << 2)) as usize]; + dest[4 * last + 2] = ALPHABET[((bytes[3 * last + 1] & 0xf0) >> 4) as usize]; + dest[4 * last + 3] = b'='; + } + _0 => {} + } + String::from_utf8(dest).unwrap() +} |