diff options
Diffstat (limited to 'server/src/db.rs')
-rw-r--r-- | server/src/db.rs | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/server/src/db.rs b/server/src/db.rs new file mode 100644 index 0000000..2211c42 --- /dev/null +++ b/server/src/db.rs @@ -0,0 +1,116 @@ +use tokio; +use tokio::sync::Mutex; +use std::{ + sync::Arc, + ops::DerefMut, + time::SystemTime, + fs, + process, +}; +use sqlx::{ + ConnectOptions, + SqliteConnection, + Row, +}; +use sqlx::sqlite::{ + SqliteConnectOptions, + SqliteJournalMode, +}; +use rand::RngCore; +use argon2::password_hash::rand_core::OsRng; + +use crate::constants::*; +use crate::util::base64_encode; + +async fn initialise_schema(conn: &mut SqliteConnection) { + match sqlx::query(&fs::read_to_string("schema.sql").unwrap()).execute(conn).await { + Ok(_) => {}, + Err(err) => { + eprintln!("Error: Could not initialise database schema: {err}"); + process::exit(1); + }, + } +} + +pub async fn open() -> SqliteConnection { + let opts = SqliteConnectOptions::new() + .filename(DB_FILE_NAME) + .create_if_missing(true) + .journal_mode(SqliteJournalMode::Wal) + .pragma("foreign_keys", "on"); + + let mut conn = match opts.connect().await { + Ok(conn) => conn, + Err(err) => { + eprintln!("Error: Could not open database file '{DB_FILE_NAME}': {err}"); + process::exit(1); + } + }; + + match sqlx::query("select version from meta").fetch_one(&mut conn).await { + Ok(row) => { + let ver: i64 = row.get(0); + if ver == DB_SCHEMA_VERSION { + conn + } else { + eprintln!("Error: Database schema version {ver} but application schema version {DB_SCHEMA_VERSION}"); + process::exit(1); + } + }, + Err(err) => { + eprintln!("Error: Failed querying database meta version: {err}"); + eprintln!("Initialising schema."); + initialise_schema(&mut conn).await; + conn + }, + } +} + +pub type DB = Arc<Mutex<SqliteConnection>>; + +pub async fn register_account(db: DB, username: &str, passhash: &str) -> Result<(), String> { + let mut conn = db.lock().await; + match sqlx::query("insert into Users (username, passhash) values ($1, $2)") + .bind(username) + .bind(passhash) + .execute(conn.deref_mut()) + .await { + Ok(_) => Ok(()), + Err(_) => Err("User already exists".to_string()), + } +} + +pub async fn get_passhash(db: DB, username: &str) -> Result<String, ()> { + let mut conn = db.lock().await; + match sqlx::query("select passhash from Users where username = $1") + .bind(username) + .fetch_optional(conn.deref_mut()) + .await { + Ok(Some(row)) => Ok(row.get(0)), + Ok(None) => Err(()), + Err(err) => { + eprintln!("db_get_passhash: sqlx error: {err}"); + Err(()) + } + } +} + +fn generate_login_token() -> String { + let mut bytes = [0u8; 32]; + OsRng.fill_bytes(&mut bytes); + base64_encode(&bytes) +} + +pub async fn create_login_token(db: DB, username: &str) -> Result<String, ()> { + let mut conn = db.lock().await; + let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs(); + let token = generate_login_token(); + match sqlx::query("insert into Logins (user, token, expire) values ($1, $2, $3)") + .bind(username) + .bind(&token) + .bind(now as i64 + LOGIN_TOKEN_EXPIRE_SECS) + .execute(conn.deref_mut()).await { + Ok(_) => Ok(token), + Err(_) => Err(()), + } +} |