summaryrefslogtreecommitdiff
path: root/server/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'server/src/main.rs')
-rw-r--r--server/src/main.rs238
1 files changed, 232 insertions, 6 deletions
diff --git a/server/src/main.rs b/server/src/main.rs
index 40eb9c8..052cee2 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -1,13 +1,239 @@
use tokio;
-use warp::Filter;
+use tokio::sync::Mutex;
+use std::{
+ sync::Arc,
+ ops::DerefMut,
+ time::SystemTime,
+ fs,
+ process,
+};
+use sqlx::{
+ ConnectOptions,
+ SqliteConnection,
+ Row,
+};
+use sqlx::sqlite::{
+ SqliteConnectOptions,
+ SqliteJournalMode,
+};
+use warp::{Filter, Reply, Rejection};
+use argon2::{
+ password_hash::{
+ rand_core::OsRng,
+ PasswordHash, PasswordHasher, PasswordVerifier, SaltString
+ },
+ Argon2
+};
+use serde::Deserialize;
+use rand::RngCore;
+
+const DB_FILE_NAME: &str = "notes.db";
+const DB_SCHEMA_VERSION: i64 = 1;
+
+fn base64_encode(bytes: &[u8]) -> String {
+ const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+
+ let len = (bytes.len() + 2) / 3 * 4;
+ let mut dest = Vec::with_capacity(len);
+ dest.resize(len, 0u8);
+ for i in 0 .. bytes.len() / 3 {
+ dest[4 * i + 0] = ALPHABET[(bytes[3 * i + 0] & 0x3f) as usize];
+ dest[4 * i + 1] = ALPHABET[(((bytes[3 * i + 0] & 0xc0) >> 6) | ((bytes[3 * i + 1] & 0x0f) << 2)) as usize];
+ dest[4 * i + 2] = ALPHABET[(((bytes[3 * i + 1] & 0xf0) >> 4) | ((bytes[3 * i + 2] & 0x03) << 4)) as usize];
+ dest[4 * i + 3] = ALPHABET[((bytes[3 * i + 1] & 0xfc) >> 2) as usize];
+ }
+ let last = bytes.len() / 3;
+ match bytes.len() % 3 {
+ 1 => {
+ dest[4 * last + 0] = ALPHABET[(bytes[3 * last + 0] & 0x3f) as usize];
+ dest[4 * last + 1] = ALPHABET[((bytes[3 * last + 0] & 0xc0) >> 6) as usize];
+ dest[4 * last + 2] = b'=';
+ dest[4 * last + 3] = b'=';
+ }
+ 2 => {
+ dest[4 * last + 0] = ALPHABET[(bytes[3 * last + 0] & 0x3f) as usize];
+ dest[4 * last + 1] = ALPHABET[(((bytes[3 * last + 0] & 0xc0) >> 6) | ((bytes[3 * last + 1] & 0x0f) << 2)) as usize];
+ dest[4 * last + 2] = ALPHABET[((bytes[3 * last + 1] & 0xf0) >> 4) as usize];
+ dest[4 * last + 3] = b'=';
+ }
+ _0 => {}
+ }
+ String::from_utf8(dest).unwrap()
+}
+
+async fn db_initialise_schema(conn: &mut SqliteConnection) {
+ match sqlx::query(&fs::read_to_string("schema.sql").unwrap()).execute(conn).await {
+ Ok(_) => {},
+ Err(err) => {
+ eprintln!("Error: Could not initialise database schema: {err}");
+ process::exit(1);
+ },
+ }
+}
+
+async fn open_db() -> SqliteConnection {
+ let opts = SqliteConnectOptions::new()
+ .filename(DB_FILE_NAME)
+ .create_if_missing(true)
+ .journal_mode(SqliteJournalMode::Wal)
+ .pragma("foreign_keys", "on");
+
+ let mut conn = match opts.connect().await {
+ Ok(conn) => conn,
+ Err(err) => {
+ eprintln!("Error: Could not open database file '{DB_FILE_NAME}': {err}");
+ process::exit(1);
+ }
+ };
+
+ match sqlx::query("select version from meta").fetch_one(&mut conn).await {
+ Ok(row) => {
+ let ver: i64 = row.get(0);
+ if ver == DB_SCHEMA_VERSION {
+ conn
+ } else {
+ eprintln!("Error: Database schema version {ver} but application schema version {DB_SCHEMA_VERSION}");
+ process::exit(1);
+ }
+ },
+ Err(err) => {
+ eprintln!("Error: Failed querying database meta version: {err}");
+ eprintln!("Initialising schema.");
+ db_initialise_schema(&mut conn).await;
+ conn
+ },
+ }
+}
+
+type DB = Arc<Mutex<SqliteConnection>>;
+
+async fn db_register_account(db: DB, username: &str, passhash: &str) -> Result<(), String> {
+ let mut conn = db.lock().await;
+ match sqlx::query("insert into Users (username, passhash) values ($1, $2)")
+ .bind(username)
+ .bind(passhash)
+ .execute(conn.deref_mut())
+ .await {
+ Ok(_) => Ok(()),
+ Err(_) => Err("User already exists".to_string()),
+ }
+}
+
+async fn db_get_passhash(db: DB, username: &str) -> Result<String, ()> {
+ let mut conn = db.lock().await;
+ match sqlx::query("select passhash from Users where username = $1")
+ .bind(username)
+ .fetch_optional(conn.deref_mut())
+ .await {
+ Ok(Some(row)) => Ok(row.get(0)),
+ Ok(None) => Err(()),
+ Err(err) => {
+ eprintln!("db_get_passhash: sqlx error: {err}");
+ Err(())
+ }
+ }
+}
+
+fn generate_login_token() -> String {
+ let mut bytes = [0u8; 32];
+ OsRng.fill_bytes(&mut bytes);
+ base64_encode(&bytes)
+}
+
+async fn db_create_login_token(db: DB, username: &str) -> Result<String, ()> {
+ let mut conn = db.lock().await;
+ let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
+ let token = generate_login_token();
+ match sqlx::query("insert into Logins (user, token, expire) values ($1, $2, $3)")
+ .bind(username)
+ .bind(&token)
+ .bind(now as i64 + 8 * 24 * 3600) // valid for 8 days
+ .execute(conn.deref_mut()).await {
+ Ok(_) => Ok(token),
+ Err(_) => Err(()),
+ }
+}
+
+async fn handle_ping() -> Result<String, Rejection> {
+ return Ok("pong".to_string())
+}
+
+#[derive(Deserialize)]
+struct RegisterReq {
+ username: String,
+ password: String,
+}
+
+type Response = Result<Box<dyn Reply>, Rejection>;
+
+macro_rules! mk_bad_request {
+ ($res:expr) => { Ok(Box::new(warp::reply::with_status($res, warp::http::StatusCode::BAD_REQUEST))) }
+}
+
+macro_rules! mk_not_found {
+ ($res:expr) => { Ok(Box::new(warp::reply::with_status($res, warp::http::StatusCode::NOT_FOUND))) }
+}
+
+macro_rules! mk_server_err {
+ () => { Ok(Box::new(warp::reply::with_status("Internal server error", warp::http::StatusCode::INTERNAL_SERVER_ERROR))) }
+}
+
+async fn handle_register(db: DB, req: RegisterReq) -> Response {
+ let salt = SaltString::generate(&mut OsRng);
+ let hash = match Argon2::default().hash_password(req.password.as_bytes(), &salt) {
+ Ok(hash) => hash.to_string(),
+ Err(_) => return mk_bad_request!("bad request"),
+ };
+ match db_register_account(db, &req.username, &hash.to_string()).await {
+ Ok(()) => Ok(Box::new("Registered")),
+ Err(err) => mk_bad_request!(err),
+ }
+}
+
+async fn handle_login(db: DB, req: RegisterReq) -> Result<Box<dyn Reply>, Rejection> {
+ let passhash = match db_get_passhash(db.clone(), &req.username).await {
+ Ok(passhash) => passhash,
+ Err(()) => return mk_not_found!("User not found"),
+ };
+ let parsed_hash = match PasswordHash::new(&passhash) {
+ Ok(parsed_hash) => parsed_hash,
+ Err(err) => {
+ eprintln!("Could not parse password hash: {err}");
+ return mk_server_err!();
+ }
+ };
+ if let Err(_) = Argon2::default().verify_password(req.password.as_bytes(), &parsed_hash) {
+ return Ok(Box::new(warp::reply::with_status("Incorrect password", warp::http::StatusCode::UNAUTHORIZED)));
+ }
+ match db_create_login_token(db, &req.username).await {
+ Ok(token) => Ok(Box::new(token)),
+ Err(()) => {
+ eprintln!("Failed inserting login token for user '{0}'", &req.username);
+ mk_server_err!()
+ }
+ }
+}
+
+macro_rules! db_handler1 {
+ ($db:expr, $handler:ident) => { { let db2 = $db.clone(); move |a| $handler(db2.clone(), a) } }
+}
#[tokio::main]
async fn main() {
- let hello = warp::path!("hello" / String)
- .and_then(|name| format!("Hello, {}!", name))
- .or(|rip| _);
+ let db: DB = Arc::new(Mutex::new(open_db().await));
+ println!("Opened database at {DB_FILE_NAME}.");
+
+ let router =
+ (warp::get().and(warp::path!("ping"))
+ .and_then(handle_ping))
+ .or(warp::post().and(warp::path!("register"))
+ .and(warp::body::json())
+ .and_then(db_handler1!(db, handle_register)))
+ .or(warp::post().and(warp::path!("login"))
+ .and(warp::body::json())
+ .and_then(db_handler1!(db, handle_login)));
- warp::serve(hello)
- .run(([127, 0, 0, 1], 3030))
+ warp::serve(router)
+ .run(([0, 0, 0, 0], 8775))
.await;
}