From c04ced6609a90ddebf15c6337c7761a0697a3497 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 20 Jan 2025 23:08:11 +0100 Subject: server: Create new empty files --- server/src/db.rs | 68 +++++++++++++++++++++++++++++++++++++++++++++++++----- server/src/main.rs | 58 +++++++++++++++++++++++++++++++++++++++------- server/src/path.rs | 61 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 173 insertions(+), 14 deletions(-) create mode 100644 server/src/path.rs diff --git a/server/src/db.rs b/server/src/db.rs index bad748d..9efb130 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -20,6 +20,7 @@ use rand::RngCore; use argon2::password_hash::rand_core::OsRng; use crate::constants::*; +use crate::path::Path; use crate::util::base64_encode; async fn initialise_schema(conn: &mut SqliteConnection) { @@ -32,7 +33,9 @@ async fn initialise_schema(conn: &mut SqliteConnection) { } } -pub async fn open() -> SqliteConnection { +pub type DB = Arc>; + +pub async fn open() -> DB { let opts = SqliteConnectOptions::new() .filename(DB_FILE_NAME) .create_if_missing(true) @@ -51,7 +54,7 @@ pub async fn open() -> SqliteConnection { Ok(row) => { let ver: i64 = row.get(0); if ver == DB_SCHEMA_VERSION { - conn + Arc::new(Mutex::new(conn)) } else { eprintln!("Error: Database schema version {ver} but application schema version {DB_SCHEMA_VERSION}"); process::exit(1); @@ -61,13 +64,11 @@ pub async fn open() -> SqliteConnection { eprintln!("Error: Failed querying database meta version: {err}"); eprintln!("Initialising schema."); initialise_schema(&mut conn).await; - conn + Arc::new(Mutex::new(conn)) }, } } -pub type DB = Arc>; - pub async fn register_account(db: DB, username: &str, passhash: &str) -> Result<(), String> { let mut conn = db.lock().await; match sqlx::query("insert into Users (username, passhash) values ($1, $2)") @@ -109,7 +110,7 @@ pub async fn create_login_token(db: DB, username: &str) -> Result { let mut conn = db.lock().await; let now = current_time(); let token = generate_login_token(); - match sqlx::query("insert into Logins (user, token, expire) values ($1, $2, $3)") + match sqlx::query("insert into Logins (user, token, expire) values ((select id from Users where username = $1), $2, $3)") .bind(username) .bind(&token) .bind(now + LOGIN_TOKEN_EXPIRE_SECS) @@ -162,3 +163,58 @@ pub async fn maybe_refresh_token(db: DB, token: &str) -> Result<(), String> { } } } + +#[derive(Copy, Clone)] +pub struct UserId { + id: i64, +} + +pub async fn check_login(db: DB, token: &str) -> Result { + let mut conn = db.lock().await; + match sqlx::query("select user from Logins where token = $1") + .bind(token) + .fetch_optional(conn.deref_mut()).await { + Ok(Some(row)) => Ok(UserId { id: row.get(0) }), + Ok(None) => Err(()), + Err(err) => { + eprintln!("check_login: err = {err}"); + Err(()) + } + } +} + +/// Returns ID of the file. +pub async fn file_lookup(db: DB, user: UserId, path: &Path<'_>) -> Option { + let mut conn = db.lock().await; + match sqlx::query("select id from Files where owner = $1 and path = $2") + .bind(user.id) + .bind(path.join()) + .fetch_optional(conn.deref_mut()).await { + Ok(Some(row)) => Some(row.get(0)), + Ok(None) => None, + Err(err) => { + eprintln!("file_exists: err = {err}"); + None + } + } +} + +/// Returns ID of created file. +pub async fn file_create_empty(db: DB, user: UserId, path: &Path<'_>) -> Result { + if let Some((parent, _)) = path.without_last() { + if file_lookup(db.clone(), user, &parent).await.is_none() { + return Err("Parent does not exist".to_string()); + } + } + + let mut conn = db.lock().await; + match sqlx::query("insert into Files (owner, path) values ($1, $2)") + .bind(user.id) + .bind(path.join()) + .execute(conn.deref_mut()).await { + Ok(res) => Ok(res.last_insert_rowid()), + Err(_) => { // let's assume it was a collision + Err("File already exists".to_string()) + } + } +} diff --git a/server/src/main.rs b/server/src/main.rs index 3a1fd52..88eb903 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -2,10 +2,6 @@ // - Have a task that cleans up expired logins once in a while (once every day?) use tokio; -use tokio::sync::Mutex; -use std::{ - sync::Arc, -}; use warp::{Filter, Reply, Rejection}; use argon2::{ password_hash::{ @@ -14,11 +10,12 @@ use argon2::{ }, Argon2 }; -use serde::Deserialize; +use serde::{Serialize, Deserialize}; mod constants; use constants::*; mod util; mod db; use db::DB; +mod path; use path::Path; macro_rules! mk_bad_request { ($res:expr) => { Ok(Box::new(warp::reply::with_status($res, warp::http::StatusCode::BAD_REQUEST))) } @@ -28,6 +25,10 @@ macro_rules! mk_not_found { ($res:expr) => { Ok(Box::new(warp::reply::with_status($res, warp::http::StatusCode::NOT_FOUND))) } } +macro_rules! mk_unauthorized { + ($msg:expr) => { Ok(Box::new(warp::reply::with_status($msg, warp::http::StatusCode::UNAUTHORIZED))) } +} + macro_rules! mk_server_err { () => { Ok(Box::new(warp::reply::with_status("Internal server error", warp::http::StatusCode::INTERNAL_SERVER_ERROR))) } } @@ -75,7 +76,7 @@ async fn handle_login(db: DB, req: RegisterReq) -> Response { } }; if let Err(_) = Argon2::default().verify_password(req.password.as_bytes(), &parsed_hash) { - return Ok(Box::new(warp::reply::with_status("Incorrect password", warp::http::StatusCode::UNAUTHORIZED))); + return mk_unauthorized!("Incorrect password"); } match db::create_login_token(db, &req.username).await { Ok(token) => Ok(Box::new(token)), @@ -91,13 +92,50 @@ async fn handle_logout(db: DB, token: String) -> Response { Ok(Box::new("Logged out")) } +macro_rules! check_login { + ($db:expr, $token:expr) => { + match db::check_login($db.clone(), $token).await { + Ok(user) => user, + Err(()) => return mk_unauthorized!("Not logged in"), + } + } +} + +#[derive(Deserialize)] +struct FileCreateReq { + path: String, +} + +#[derive(Serialize)] +struct FileCreateRes { + id: i64, +} + +async fn handle_file_create(db: DB, token: String, req: FileCreateReq) -> Response { + let user = check_login!(db, &token); + + let path = match Path::split(&req.path) { + Some(path) => path, + None => return mk_bad_request!("Invalid path"), + }; + + match db::file_create_empty(db.clone(), user, &path).await { + Ok(id) => Ok(Box::new(warp::reply::json(&FileCreateRes { id }))), + Err(err) => mk_bad_request!(err), + } +} + macro_rules! db_handler1 { ($db:expr, $handler:ident) => { { let db2 = $db.clone(); move |a| $handler(db2.clone(), a) } } } +macro_rules! db_handler2 { + ($db:expr, $handler:ident) => { { let db2 = $db.clone(); move |a,b| $handler(db2.clone(), a, b) } } +} + #[tokio::main] async fn main() { - let db: DB = Arc::new(Mutex::new(db::open().await)); + let db: DB = db::open().await; println!("Opened database at {DB_FILE_NAME}."); let use_login_token = warp::header::("x-kaasnoot-token"); @@ -115,7 +153,11 @@ async fn main() { .and_then(db_handler1!(db, handle_login))) .or(warp::post().and(warp::path!("logout")) .and(use_login_token) - .and_then(db_handler1!(db, handle_logout))); + .and_then(db_handler1!(db, handle_logout))) + .or(warp::put().and(warp::path!("file")) + .and(use_login_token) + .and(warp::body::json()) + .and_then(db_handler2!(db, handle_file_create))); warp::serve(router) .run(([0, 0, 0, 0], 8775)) diff --git a/server/src/path.rs b/server/src/path.rs new file mode 100644 index 0000000..17aa263 --- /dev/null +++ b/server/src/path.rs @@ -0,0 +1,61 @@ +// Invariants: +// - 'orig' is a valid path, and in particular, the '/'-split components are non-empty; +// - 'comps' describes a non-empty prefix of 'orig'. +#[derive(Debug)] +pub struct Path<'a> { + orig: &'a str, + comps: Vec<(usize, &'a str)>, // (offset, component) +} + +impl<'a> Path<'a> { + /// Returns the split path if the path is valid. A valid path has the following requirements: + /// * It does not contain unicode control characters; + /// * When split on '/', the resulting components are all non-empty and neither start nor end + /// with unicode whitespace characters. + pub fn split(s: &str) -> Option { + let mut comps = Vec::new(); + let mut start = 0; // of current component + for (i, c) in s.char_indices() { + match c { + '/' if i == start => return None, // empty component + '/' => { + let comp = &s[start..i]; + if comp.starts_with(|c2: char| c2.is_whitespace()) || + comp.ends_with(|c2: char| c2.is_whitespace()) { + return None; + } + comps.push((start, comp)); + start = i + 1; + }, + _ if c.is_control() => return None, + _ => {}, // include in current component + } + } + + // check and add the last component + if start == s.len() { + return None; // slash at end of input + } + comps.push((start, &s[start..s.len()])); + + Some(Path { orig: s, comps }) + } + + pub fn without_last(&self) -> Option<(Path, &str)> { + let n = self.comps.len(); + if n > 1 { + Some((Path { orig: self.orig, comps: self.comps[0 .. n - 1].into() }, self.comps[n - 1].1)) + } else { + None + } + } + + pub fn join(&self) -> &str { + if self.comps.len() == 0 { + "" + } else { + let (lastoff, lastcomp) = self.comps[self.comps.len() - 1]; + &self.orig[0 .. lastoff + lastcomp.len()] + } + } +} -- cgit v1.2.3-70-g09d2