summaryrefslogtreecommitdiff
path: root/server/src/db.rs
diff options
context:
space:
mode:
Diffstat (limited to 'server/src/db.rs')
-rw-r--r--server/src/db.rs116
1 files changed, 116 insertions, 0 deletions
diff --git a/server/src/db.rs b/server/src/db.rs
new file mode 100644
index 0000000..2211c42
--- /dev/null
+++ b/server/src/db.rs
@@ -0,0 +1,116 @@
+use tokio;
+use tokio::sync::Mutex;
+use std::{
+ sync::Arc,
+ ops::DerefMut,
+ time::SystemTime,
+ fs,
+ process,
+};
+use sqlx::{
+ ConnectOptions,
+ SqliteConnection,
+ Row,
+};
+use sqlx::sqlite::{
+ SqliteConnectOptions,
+ SqliteJournalMode,
+};
+use rand::RngCore;
+use argon2::password_hash::rand_core::OsRng;
+
+use crate::constants::*;
+use crate::util::base64_encode;
+
+async fn initialise_schema(conn: &mut SqliteConnection) {
+ match sqlx::query(&fs::read_to_string("schema.sql").unwrap()).execute(conn).await {
+ Ok(_) => {},
+ Err(err) => {
+ eprintln!("Error: Could not initialise database schema: {err}");
+ process::exit(1);
+ },
+ }
+}
+
+pub async fn open() -> SqliteConnection {
+ let opts = SqliteConnectOptions::new()
+ .filename(DB_FILE_NAME)
+ .create_if_missing(true)
+ .journal_mode(SqliteJournalMode::Wal)
+ .pragma("foreign_keys", "on");
+
+ let mut conn = match opts.connect().await {
+ Ok(conn) => conn,
+ Err(err) => {
+ eprintln!("Error: Could not open database file '{DB_FILE_NAME}': {err}");
+ process::exit(1);
+ }
+ };
+
+ match sqlx::query("select version from meta").fetch_one(&mut conn).await {
+ Ok(row) => {
+ let ver: i64 = row.get(0);
+ if ver == DB_SCHEMA_VERSION {
+ conn
+ } else {
+ eprintln!("Error: Database schema version {ver} but application schema version {DB_SCHEMA_VERSION}");
+ process::exit(1);
+ }
+ },
+ Err(err) => {
+ eprintln!("Error: Failed querying database meta version: {err}");
+ eprintln!("Initialising schema.");
+ initialise_schema(&mut conn).await;
+ conn
+ },
+ }
+}
+
+pub type DB = Arc<Mutex<SqliteConnection>>;
+
+pub async fn register_account(db: DB, username: &str, passhash: &str) -> Result<(), String> {
+ let mut conn = db.lock().await;
+ match sqlx::query("insert into Users (username, passhash) values ($1, $2)")
+ .bind(username)
+ .bind(passhash)
+ .execute(conn.deref_mut())
+ .await {
+ Ok(_) => Ok(()),
+ Err(_) => Err("User already exists".to_string()),
+ }
+}
+
+pub async fn get_passhash(db: DB, username: &str) -> Result<String, ()> {
+ let mut conn = db.lock().await;
+ match sqlx::query("select passhash from Users where username = $1")
+ .bind(username)
+ .fetch_optional(conn.deref_mut())
+ .await {
+ Ok(Some(row)) => Ok(row.get(0)),
+ Ok(None) => Err(()),
+ Err(err) => {
+ eprintln!("db_get_passhash: sqlx error: {err}");
+ Err(())
+ }
+ }
+}
+
+fn generate_login_token() -> String {
+ let mut bytes = [0u8; 32];
+ OsRng.fill_bytes(&mut bytes);
+ base64_encode(&bytes)
+}
+
+pub async fn create_login_token(db: DB, username: &str) -> Result<String, ()> {
+ let mut conn = db.lock().await;
+ let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
+ let token = generate_login_token();
+ match sqlx::query("insert into Logins (user, token, expire) values ($1, $2, $3)")
+ .bind(username)
+ .bind(&token)
+ .bind(now as i64 + LOGIN_TOKEN_EXPIRE_SECS)
+ .execute(conn.deref_mut()).await {
+ Ok(_) => Ok(token),
+ Err(_) => Err(()),
+ }
+}