use tokio; use tokio::sync::Mutex; use std::{ sync::Arc, ops::DerefMut, time::SystemTime, fs, process, }; use sqlx::{ ConnectOptions, SqliteConnection, Row, }; use sqlx::sqlite::{ SqliteConnectOptions, SqliteJournalMode, }; use rand::RngCore; use argon2::password_hash::rand_core::OsRng; use crate::constants::*; use crate::util::base64_encode; async fn initialise_schema(conn: &mut SqliteConnection) { match sqlx::query(&fs::read_to_string("schema.sql").unwrap()).execute(conn).await { Ok(_) => {}, Err(err) => { eprintln!("Error: Could not initialise database schema: {err}"); process::exit(1); }, } } pub async fn open() -> SqliteConnection { let opts = SqliteConnectOptions::new() .filename(DB_FILE_NAME) .create_if_missing(true) .journal_mode(SqliteJournalMode::Wal) .pragma("foreign_keys", "on"); let mut conn = match opts.connect().await { Ok(conn) => conn, Err(err) => { eprintln!("Error: Could not open database file '{DB_FILE_NAME}': {err}"); process::exit(1); } }; match sqlx::query("select version from meta").fetch_one(&mut conn).await { Ok(row) => { let ver: i64 = row.get(0); if ver == DB_SCHEMA_VERSION { conn } else { eprintln!("Error: Database schema version {ver} but application schema version {DB_SCHEMA_VERSION}"); process::exit(1); } }, Err(err) => { eprintln!("Error: Failed querying database meta version: {err}"); eprintln!("Initialising schema."); initialise_schema(&mut conn).await; conn }, } } pub type DB = Arc>; pub async fn register_account(db: DB, username: &str, passhash: &str) -> Result<(), String> { let mut conn = db.lock().await; match sqlx::query("insert into Users (username, passhash) values ($1, $2)") .bind(username) .bind(passhash) .execute(conn.deref_mut()) .await { Ok(_) => Ok(()), Err(_) => Err("User already exists".to_string()), } } pub async fn get_passhash(db: DB, username: &str) -> Result { let mut conn = db.lock().await; match sqlx::query("select passhash from Users where username = $1") .bind(username) .fetch_optional(conn.deref_mut()) .await { Ok(Some(row)) => Ok(row.get(0)), Ok(None) => Err(()), Err(err) => { eprintln!("db_get_passhash: sqlx error: {err}"); Err(()) } } } fn generate_login_token() -> String { let mut bytes = [0u8; 30]; OsRng.fill_bytes(&mut bytes); base64_encode(&bytes) } fn current_time() -> i64 { SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() as i64 } pub async fn create_login_token(db: DB, username: &str) -> Result { let mut conn = db.lock().await; let now = current_time(); let token = generate_login_token(); match sqlx::query("insert into Logins (user, token, expire) values ($1, $2, $3)") .bind(username) .bind(&token) .bind(now + LOGIN_TOKEN_EXPIRE_SECS) .execute(conn.deref_mut()).await { Ok(_) => Ok(token), Err(_) => Err(()), } } pub async fn drop_token(db: DB, token: &str) { let mut conn = db.lock().await; // ignore errors let _ = sqlx::query("delete from Logins where token = $1") .bind(token) .execute(conn.deref_mut()).await; } async fn set_token_expire(conn: &mut SqliteConnection, token: &str, expire: i64) -> Result<(), String> { match sqlx::query("update Logins set expire = $1 where token = $2") .bind(expire) .bind(token) .execute(conn).await { Ok(_) => Ok(()), Err(err) => { eprintln!("set_token_expire: err = {err}"); Err("Server error".to_string()) } } } pub async fn maybe_refresh_token(db: DB, token: &str) -> Result<(), String> { let mut conn = db.lock().await; let now = current_time(); match sqlx::query("select expire from Logins where token = $1") .bind(token) .fetch_optional(conn.deref_mut()).await { Ok(Some(row)) => { if now >= row.get::(0) - LOGIN_TOKEN_REFRESH_MARGIN { set_token_expire(conn.deref_mut(), token, now + LOGIN_TOKEN_EXPIRE_SECS).await } else { Ok(()) } }, Ok(None) => { Err("Not logged in".to_string()) }, Err(err) => { eprintln!("maybe_refresh_token: err = {err}"); Err("Server error".to_string()) } } }