summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/src/db.rs68
-rw-r--r--server/src/main.rs58
-rw-r--r--server/src/path.rs61
3 files changed, 173 insertions, 14 deletions
diff --git a/server/src/db.rs b/server/src/db.rs
index bad748d..9efb130 100644
--- a/server/src/db.rs
+++ b/server/src/db.rs
@@ -20,6 +20,7 @@ use rand::RngCore;
use argon2::password_hash::rand_core::OsRng;
use crate::constants::*;
+use crate::path::Path;
use crate::util::base64_encode;
async fn initialise_schema(conn: &mut SqliteConnection) {
@@ -32,7 +33,9 @@ async fn initialise_schema(conn: &mut SqliteConnection) {
}
}
-pub async fn open() -> SqliteConnection {
+pub type DB = Arc<Mutex<SqliteConnection>>;
+
+pub async fn open() -> DB {
let opts = SqliteConnectOptions::new()
.filename(DB_FILE_NAME)
.create_if_missing(true)
@@ -51,7 +54,7 @@ pub async fn open() -> SqliteConnection {
Ok(row) => {
let ver: i64 = row.get(0);
if ver == DB_SCHEMA_VERSION {
- conn
+ Arc::new(Mutex::new(conn))
} else {
eprintln!("Error: Database schema version {ver} but application schema version {DB_SCHEMA_VERSION}");
process::exit(1);
@@ -61,13 +64,11 @@ pub async fn open() -> SqliteConnection {
eprintln!("Error: Failed querying database meta version: {err}");
eprintln!("Initialising schema.");
initialise_schema(&mut conn).await;
- conn
+ Arc::new(Mutex::new(conn))
},
}
}
-pub type DB = Arc<Mutex<SqliteConnection>>;
-
pub async fn register_account(db: DB, username: &str, passhash: &str) -> Result<(), String> {
let mut conn = db.lock().await;
match sqlx::query("insert into Users (username, passhash) values ($1, $2)")
@@ -109,7 +110,7 @@ pub async fn create_login_token(db: DB, username: &str) -> Result<String, ()> {
let mut conn = db.lock().await;
let now = current_time();
let token = generate_login_token();
- match sqlx::query("insert into Logins (user, token, expire) values ($1, $2, $3)")
+ match sqlx::query("insert into Logins (user, token, expire) values ((select id from Users where username = $1), $2, $3)")
.bind(username)
.bind(&token)
.bind(now + LOGIN_TOKEN_EXPIRE_SECS)
@@ -162,3 +163,58 @@ pub async fn maybe_refresh_token(db: DB, token: &str) -> Result<(), String> {
}
}
}
+
+#[derive(Copy, Clone)]
+pub struct UserId {
+ id: i64,
+}
+
+pub async fn check_login(db: DB, token: &str) -> Result<UserId, ()> {
+ let mut conn = db.lock().await;
+ match sqlx::query("select user from Logins where token = $1")
+ .bind(token)
+ .fetch_optional(conn.deref_mut()).await {
+ Ok(Some(row)) => Ok(UserId { id: row.get(0) }),
+ Ok(None) => Err(()),
+ Err(err) => {
+ eprintln!("check_login: err = {err}");
+ Err(())
+ }
+ }
+}
+
+/// Returns ID of the file.
+pub async fn file_lookup(db: DB, user: UserId, path: &Path<'_>) -> Option<i64> {
+ let mut conn = db.lock().await;
+ match sqlx::query("select id from Files where owner = $1 and path = $2")
+ .bind(user.id)
+ .bind(path.join())
+ .fetch_optional(conn.deref_mut()).await {
+ Ok(Some(row)) => Some(row.get(0)),
+ Ok(None) => None,
+ Err(err) => {
+ eprintln!("file_exists: err = {err}");
+ None
+ }
+ }
+}
+
+/// Returns ID of created file.
+pub async fn file_create_empty(db: DB, user: UserId, path: &Path<'_>) -> Result<i64, String> {
+ if let Some((parent, _)) = path.without_last() {
+ if file_lookup(db.clone(), user, &parent).await.is_none() {
+ return Err("Parent does not exist".to_string());
+ }
+ }
+
+ let mut conn = db.lock().await;
+ match sqlx::query("insert into Files (owner, path) values ($1, $2)")
+ .bind(user.id)
+ .bind(path.join())
+ .execute(conn.deref_mut()).await {
+ Ok(res) => Ok(res.last_insert_rowid()),
+ Err(_) => { // let's assume it was a collision
+ Err("File already exists".to_string())
+ }
+ }
+}
diff --git a/server/src/main.rs b/server/src/main.rs
index 3a1fd52..88eb903 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -2,10 +2,6 @@
// - Have a task that cleans up expired logins once in a while (once every day?)
use tokio;
-use tokio::sync::Mutex;
-use std::{
- sync::Arc,
-};
use warp::{Filter, Reply, Rejection};
use argon2::{
password_hash::{
@@ -14,11 +10,12 @@ use argon2::{
},
Argon2
};
-use serde::Deserialize;
+use serde::{Serialize, Deserialize};
mod constants; use constants::*;
mod util;
mod db; use db::DB;
+mod path; use path::Path;
macro_rules! mk_bad_request {
($res:expr) => { Ok(Box::new(warp::reply::with_status($res, warp::http::StatusCode::BAD_REQUEST))) }
@@ -28,6 +25,10 @@ macro_rules! mk_not_found {
($res:expr) => { Ok(Box::new(warp::reply::with_status($res, warp::http::StatusCode::NOT_FOUND))) }
}
+macro_rules! mk_unauthorized {
+ ($msg:expr) => { Ok(Box::new(warp::reply::with_status($msg, warp::http::StatusCode::UNAUTHORIZED))) }
+}
+
macro_rules! mk_server_err {
() => { Ok(Box::new(warp::reply::with_status("Internal server error", warp::http::StatusCode::INTERNAL_SERVER_ERROR))) }
}
@@ -75,7 +76,7 @@ async fn handle_login(db: DB, req: RegisterReq) -> Response {
}
};
if let Err(_) = Argon2::default().verify_password(req.password.as_bytes(), &parsed_hash) {
- return Ok(Box::new(warp::reply::with_status("Incorrect password", warp::http::StatusCode::UNAUTHORIZED)));
+ return mk_unauthorized!("Incorrect password");
}
match db::create_login_token(db, &req.username).await {
Ok(token) => Ok(Box::new(token)),
@@ -91,13 +92,50 @@ async fn handle_logout(db: DB, token: String) -> Response {
Ok(Box::new("Logged out"))
}
+macro_rules! check_login {
+ ($db:expr, $token:expr) => {
+ match db::check_login($db.clone(), $token).await {
+ Ok(user) => user,
+ Err(()) => return mk_unauthorized!("Not logged in"),
+ }
+ }
+}
+
+#[derive(Deserialize)]
+struct FileCreateReq {
+ path: String,
+}
+
+#[derive(Serialize)]
+struct FileCreateRes {
+ id: i64,
+}
+
+async fn handle_file_create(db: DB, token: String, req: FileCreateReq) -> Response {
+ let user = check_login!(db, &token);
+
+ let path = match Path::split(&req.path) {
+ Some(path) => path,
+ None => return mk_bad_request!("Invalid path"),
+ };
+
+ match db::file_create_empty(db.clone(), user, &path).await {
+ Ok(id) => Ok(Box::new(warp::reply::json(&FileCreateRes { id }))),
+ Err(err) => mk_bad_request!(err),
+ }
+}
+
macro_rules! db_handler1 {
($db:expr, $handler:ident) => { { let db2 = $db.clone(); move |a| $handler(db2.clone(), a) } }
}
+macro_rules! db_handler2 {
+ ($db:expr, $handler:ident) => { { let db2 = $db.clone(); move |a,b| $handler(db2.clone(), a, b) } }
+}
+
#[tokio::main]
async fn main() {
- let db: DB = Arc::new(Mutex::new(db::open().await));
+ let db: DB = db::open().await;
println!("Opened database at {DB_FILE_NAME}.");
let use_login_token = warp::header::<String>("x-kaasnoot-token");
@@ -115,7 +153,11 @@ async fn main() {
.and_then(db_handler1!(db, handle_login)))
.or(warp::post().and(warp::path!("logout"))
.and(use_login_token)
- .and_then(db_handler1!(db, handle_logout)));
+ .and_then(db_handler1!(db, handle_logout)))
+ .or(warp::put().and(warp::path!("file"))
+ .and(use_login_token)
+ .and(warp::body::json())
+ .and_then(db_handler2!(db, handle_file_create)));
warp::serve(router)
.run(([0, 0, 0, 0], 8775))
diff --git a/server/src/path.rs b/server/src/path.rs
new file mode 100644
index 0000000..17aa263
--- /dev/null
+++ b/server/src/path.rs
@@ -0,0 +1,61 @@
+// Invariants:
+// - 'orig' is a valid path, and in particular, the '/'-split components are non-empty;
+// - 'comps' describes a non-empty prefix of 'orig'.
+#[derive(Debug)]
+pub struct Path<'a> {
+ orig: &'a str,
+ comps: Vec<(usize, &'a str)>, // (offset, component)
+}
+
+impl<'a> Path<'a> {
+ /// Returns the split path if the path is valid. A valid path has the following requirements:
+ /// * It does not contain unicode control characters;
+ /// * When split on '/', the resulting components are all non-empty and neither start nor end
+ /// with unicode whitespace characters.
+ pub fn split(s: &str) -> Option<Path> {
+ let mut comps = Vec::new();
+ let mut start = 0; // of current component
+ for (i, c) in s.char_indices() {
+ match c {
+ '/' if i == start => return None, // empty component
+ '/' => {
+ let comp = &s[start..i];
+ if comp.starts_with(|c2: char| c2.is_whitespace()) ||
+ comp.ends_with(|c2: char| c2.is_whitespace()) {
+ return None;
+ }
+ comps.push((start, comp));
+ start = i + 1;
+ },
+ _ if c.is_control() => return None,
+ _ => {}, // include in current component
+ }
+ }
+
+ // check and add the last component
+ if start == s.len() {
+ return None; // slash at end of input
+ }
+ comps.push((start, &s[start..s.len()]));
+
+ Some(Path { orig: s, comps })
+ }
+
+ pub fn without_last(&self) -> Option<(Path, &str)> {
+ let n = self.comps.len();
+ if n > 1 {
+ Some((Path { orig: self.orig, comps: self.comps[0 .. n - 1].into() }, self.comps[n - 1].1))
+ } else {
+ None
+ }
+ }
+
+ pub fn join(&self) -> &str {
+ if self.comps.len() == 0 {
+ ""
+ } else {
+ let (lastoff, lastcomp) = self.comps[self.comps.len() - 1];
+ &self.orig[0 .. lastoff + lastcomp.len()]
+ }
+ }
+}