postgres via sqlx - workable?
This commit is contained in:
@@ -7,12 +7,12 @@ use axum::{
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use base64::{Engine, prelude::BASE64_STANDARD};
|
||||
use rusqlite::OptionalExtension;
|
||||
use sqlx::{PgConnection, Row};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
ISE_MSG,
|
||||
database::{self, DatabaseError},
|
||||
database::DatabaseError,
|
||||
users::{
|
||||
User,
|
||||
auth::{
|
||||
@@ -53,8 +53,8 @@ impl IntoResponse for AuthError {
|
||||
}
|
||||
}
|
||||
}
|
||||
impl From<rusqlite::Error> for AuthError {
|
||||
fn from(value: rusqlite::Error) -> Self {
|
||||
impl From<sqlx::Error> for AuthError {
|
||||
fn from(value: sqlx::Error) -> Self {
|
||||
AuthError::DatabaseError(DatabaseError::from(value))
|
||||
}
|
||||
}
|
||||
@@ -122,21 +122,27 @@ impl<'a> AuthScheme<'a> {
|
||||
}
|
||||
|
||||
impl UserAuthenticate for User {
|
||||
fn authenticate(headers: &HeaderMap) -> Result<Option<User>, AuthError> {
|
||||
async fn authenticate(
|
||||
conn: &mut PgConnection,
|
||||
headers: &HeaderMap,
|
||||
) -> Result<Option<User>, AuthError> {
|
||||
let (basic_auth, bearer_auth) = auth_common(headers);
|
||||
|
||||
match (basic_auth, bearer_auth) {
|
||||
(Some(creds), _) => authenticate_basic(&creds),
|
||||
(None, Some(token)) => authenticate_bearer(&token),
|
||||
(Some(creds), _) => authenticate_basic(conn, &creds).await,
|
||||
(None, Some(token)) => authenticate_bearer(conn, &token).await,
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl SessionAuthenticate for Session {
|
||||
fn authenticate(headers: &HeaderMap) -> Result<Option<Session>, AuthError> {
|
||||
async fn authenticate(
|
||||
conn: &mut PgConnection,
|
||||
headers: &HeaderMap,
|
||||
) -> Result<Option<Session>, AuthError> {
|
||||
let (_, bearer_auth) = auth_common(headers);
|
||||
if let Some(token) = bearer_auth {
|
||||
authenticate_bearer_with_session(&token)
|
||||
authenticate_bearer_with_session(conn, &token).await
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
@@ -181,52 +187,71 @@ fn auth_common(headers: &HeaderMap) -> (Option<String>, Option<String>) {
|
||||
(basic_auth, bearer_auth)
|
||||
}
|
||||
|
||||
fn authenticate_basic(credentials: &str) -> Result<Option<User>, AuthError> {
|
||||
async fn authenticate_basic(
|
||||
conn: &mut PgConnection,
|
||||
credentials: &str,
|
||||
) -> Result<Option<User>, AuthError> {
|
||||
let decoded = BASE64_STANDARD.decode(credentials)?;
|
||||
let credentials_str = String::from_utf8(decoded)?;
|
||||
|
||||
let Some((handle, password)) = credentials_str.split_once(':') else {
|
||||
return Err(AuthError::InvalidFormat);
|
||||
};
|
||||
authenticate_via_credentials(handle, password)
|
||||
authenticate_via_credentials(conn, handle, password).await
|
||||
}
|
||||
pub fn authenticate_via_credentials(
|
||||
|
||||
pub async fn authenticate_via_credentials(
|
||||
conn: &mut PgConnection,
|
||||
handle: &str,
|
||||
password: &str,
|
||||
) -> Result<Option<User>, AuthError> {
|
||||
let conn = database::conn()?;
|
||||
let user: Option<(Uuid, Option<String>)> = conn
|
||||
.prepare("SELECT id, password FROM users WHERE handle = ?1")?
|
||||
.query_row([handle], |r| Ok((r.get(0)?, r.get(1)?)))
|
||||
.optional()?;
|
||||
let row = sqlx::query("SELECT id, password FROM users WHERE handle = $1")
|
||||
.bind(handle)
|
||||
.fetch_optional(&mut *conn)
|
||||
.await?;
|
||||
|
||||
match user {
|
||||
Some((id, Some(passhash))) => match User::match_hash_password(password, &passhash)? {
|
||||
true => Ok(Some(User::get_by_id(&conn, id)?)),
|
||||
false => Err(AuthError::InvalidCredentials),
|
||||
},
|
||||
_ => {
|
||||
match row {
|
||||
Some(r) => {
|
||||
let id: Uuid = r.try_get("id")?;
|
||||
let passhash: Option<String> = r.try_get("password")?;
|
||||
match passhash {
|
||||
Some(p) => match User::match_hash_password(password, &p)? {
|
||||
true => Ok(Some(User::get_by_id(conn, id).await?)),
|
||||
false => Err(AuthError::InvalidCredentials),
|
||||
},
|
||||
None => {
|
||||
let _ = User::match_hash_password(DUMMY_PASSWORD, &DUMMY_PASSWORD_PHC)?;
|
||||
Err(AuthError::InvalidCredentials)
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let _ = User::match_hash_password(DUMMY_PASSWORD, &DUMMY_PASSWORD_PHC)?;
|
||||
Err(AuthError::InvalidCredentials)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn authenticate_bearer(token: &str) -> Result<Option<User>, AuthError> {
|
||||
let conn = database::conn().map_err(|e| DatabaseError::from(e))?;
|
||||
let mut s = Session::get_by_token(&conn, token)?;
|
||||
async fn authenticate_bearer(
|
||||
conn: &mut PgConnection,
|
||||
token: &str,
|
||||
) -> Result<Option<User>, AuthError> {
|
||||
let mut s = Session::get_by_token(&mut *conn, token).await?;
|
||||
if s.is_expired_or_revoked() {
|
||||
return Err(AuthError::InvalidCredentials);
|
||||
}
|
||||
s.prolong(&conn)?;
|
||||
Ok(Some(User::get_by_id(&conn, s.user_id)?))
|
||||
s.prolong(&mut *conn).await?;
|
||||
Ok(Some(User::get_by_id(conn, s.user_id).await?))
|
||||
}
|
||||
fn authenticate_bearer_with_session(token: &str) -> Result<Option<Session>, AuthError> {
|
||||
let conn = database::conn().map_err(|e| DatabaseError::from(e))?;
|
||||
let mut s = Session::get_by_token(&conn, token)?;
|
||||
|
||||
async fn authenticate_bearer_with_session(
|
||||
conn: &mut PgConnection,
|
||||
token: &str,
|
||||
) -> Result<Option<Session>, AuthError> {
|
||||
let mut s = Session::get_by_token(&mut *conn, token).await?;
|
||||
if s.is_expired_or_revoked() {
|
||||
return Err(AuthError::InvalidCredentials);
|
||||
}
|
||||
s.prolong(&conn)?;
|
||||
s.prolong(conn).await?;
|
||||
Ok(Some(s))
|
||||
}
|
||||
|
||||
@@ -16,14 +16,22 @@ pub mod implementation;
|
||||
|
||||
pub const COOKIE_NAME: &str = "mnemohash";
|
||||
|
||||
use sqlx::PgConnection;
|
||||
|
||||
pub trait UserAuthenticate {
|
||||
fn authenticate(headers: &HeaderMap) -> Result<Option<User>, AuthError>;
|
||||
async fn authenticate(
|
||||
conn: &mut PgConnection,
|
||||
headers: &HeaderMap,
|
||||
) -> Result<Option<User>, AuthError>;
|
||||
}
|
||||
pub trait UserAuthRequired {
|
||||
fn required(self) -> Result<User, AuthError>;
|
||||
}
|
||||
pub trait SessionAuthenticate {
|
||||
fn authenticate(headers: &HeaderMap) -> Result<Option<Session>, AuthError>;
|
||||
async fn authenticate(
|
||||
conn: &mut PgConnection,
|
||||
headers: &HeaderMap,
|
||||
) -> Result<Option<Session>, AuthError>;
|
||||
}
|
||||
pub trait SessionAuthRequired {
|
||||
fn required(self) -> Result<Session, AuthError>;
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
use std::{fmt::Display, hash::Hash, ops::Deref, str::FromStr};
|
||||
|
||||
use rusqlite::{
|
||||
Result as RusqliteResult,
|
||||
types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::Type)]
|
||||
#[sqlx(transparent)]
|
||||
#[serde(into = "String")]
|
||||
#[serde(try_from = "String")]
|
||||
pub struct UserHandle(String);
|
||||
@@ -90,15 +87,3 @@ impl From<UserHandle> for String {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSql for UserHandle {
|
||||
fn to_sql(&self) -> RusqliteResult<ToSqlOutput<'_>> {
|
||||
self.0.to_sql()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromSql for UserHandle {
|
||||
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
|
||||
UserHandle::from_str(value.as_str()?).map_err(|e| FromSqlError::Other(Box::new(e)))
|
||||
}
|
||||
}
|
||||
|
||||
163
src/users/mod.rs
163
src/users/mod.rs
@@ -3,8 +3,8 @@ use axum::{
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use chrono::{DateTime, NaiveDate};
|
||||
use rusqlite::{Connection, OptionalExtension, ffi::SQLITE_CONSTRAINT_UNIQUE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{PgConnection, Row};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
@@ -45,65 +45,87 @@ pub enum UserError {
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn total_count(conn: &Connection) -> Result<i64, UserError> {
|
||||
Ok(conn.query_row("SELECT COUNT(*) FROM users", (), |r| r.get(0))?)
|
||||
pub async fn total_count(conn: &mut PgConnection) -> Result<i64, UserError> {
|
||||
Ok(sqlx::query_scalar("SELECT COUNT(*) FROM users")
|
||||
.fetch_one(conn)
|
||||
.await?)
|
||||
}
|
||||
pub fn get_by_id(conn: &Connection, id: Uuid) -> Result<User, UserError> {
|
||||
let res = conn
|
||||
.prepare("SELECT handle FROM users WHERE id = ?1")?
|
||||
.query_one((&id,), |r| {
|
||||
|
||||
pub async fn get_by_id(conn: &mut PgConnection, id: Uuid) -> Result<User, UserError> {
|
||||
let res = sqlx::query("SELECT handle FROM users WHERE id = $1")
|
||||
.bind(id)
|
||||
.fetch_optional(conn)
|
||||
.await?;
|
||||
|
||||
match res {
|
||||
Some(r) => {
|
||||
let handle_str: String = r.try_get("handle")?;
|
||||
Ok(User {
|
||||
id,
|
||||
handle: r.get(0)?,
|
||||
handle: UserHandle::new(&handle_str)?,
|
||||
})
|
||||
})
|
||||
.optional()?;
|
||||
match res {
|
||||
Some(u) => Ok(u),
|
||||
}
|
||||
None => Err(UserError::NoUserWithId(id)),
|
||||
}
|
||||
}
|
||||
pub fn get_by_handle(conn: &Connection, handle: UserHandle) -> Result<User, UserError> {
|
||||
let res = conn
|
||||
.prepare("SELECT id, handle FROM users WHERE handle = ?1")?
|
||||
.query_one((&handle,), |r| {
|
||||
Ok(User {
|
||||
id: r.get(0)?,
|
||||
handle: r.get(1)?,
|
||||
})
|
||||
})
|
||||
.optional()?;
|
||||
|
||||
pub async fn get_by_handle(
|
||||
conn: &mut PgConnection,
|
||||
handle: UserHandle,
|
||||
) -> Result<User, UserError> {
|
||||
let res = sqlx::query("SELECT id FROM users WHERE handle = $1")
|
||||
.bind(handle.as_str())
|
||||
.fetch_optional(conn)
|
||||
.await?;
|
||||
|
||||
match res {
|
||||
Some(u) => Ok(u),
|
||||
Some(r) => Ok(User {
|
||||
id: r.try_get("id")?,
|
||||
handle,
|
||||
}),
|
||||
None => Err(UserError::NoUserWithHandle(handle)),
|
||||
}
|
||||
}
|
||||
pub fn get_all(conn: &Connection) -> Result<Vec<User>, UserError> {
|
||||
Ok(conn
|
||||
.prepare("SELECT id, handle FROM users")?
|
||||
.query_map((), |r| {
|
||||
Ok(User {
|
||||
id: r.get(0)?,
|
||||
handle: r.get(1)?,
|
||||
})
|
||||
})?
|
||||
.collect::<Result<Vec<User>, _>>()?)
|
||||
|
||||
pub async fn get_all(conn: &mut PgConnection) -> Result<Vec<User>, UserError> {
|
||||
let rows = sqlx::query("SELECT id, handle FROM users")
|
||||
.fetch_all(conn)
|
||||
.await?;
|
||||
|
||||
let mut users = Vec::with_capacity(rows.len());
|
||||
for r in rows {
|
||||
let handle_str: String = r.try_get("handle")?;
|
||||
users.push(User {
|
||||
id: r.try_get("id")?,
|
||||
handle: UserHandle::new(&handle_str)?,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(users)
|
||||
}
|
||||
|
||||
pub fn create(conn: &Connection, handle: UserHandle) -> Result<User, UserError> {
|
||||
pub async fn create(conn: &mut PgConnection, handle: UserHandle) -> Result<User, UserError> {
|
||||
let id = Uuid::now_v7();
|
||||
conn.prepare("INSERT INTO users(id, handle) VALUES (?1, ?2)")?
|
||||
.execute((&id, &handle))?;
|
||||
sqlx::query("INSERT INTO users(id, handle) VALUES ($1, $2)")
|
||||
.bind(id)
|
||||
.bind(handle.as_str())
|
||||
.execute(conn)
|
||||
.await?;
|
||||
|
||||
Ok(User { id, handle })
|
||||
}
|
||||
|
||||
pub fn set_handle(
|
||||
pub async fn set_handle(
|
||||
&mut self,
|
||||
conn: &Connection,
|
||||
conn: &mut PgConnection,
|
||||
new_handle: UserHandle,
|
||||
) -> Result<(), UserError> {
|
||||
conn.prepare("UPDATE users SET handle = ?1 WHERE id = ?2")?
|
||||
.execute((&new_handle, self.id))?;
|
||||
sqlx::query("UPDATE users SET handle = $1 WHERE id = $2")
|
||||
.bind(new_handle.as_str())
|
||||
.bind(self.id)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
|
||||
self.handle = new_handle;
|
||||
Ok(())
|
||||
}
|
||||
@@ -118,21 +140,26 @@ impl User {
|
||||
|
||||
// DANGEROUS: AUTH
|
||||
impl User {
|
||||
pub fn set_password(
|
||||
pub async fn set_password(
|
||||
&mut self,
|
||||
conn: &Connection,
|
||||
conn: &mut PgConnection,
|
||||
passw: Option<&str>,
|
||||
) -> Result<(), UserError> {
|
||||
match passw {
|
||||
None => {
|
||||
conn.prepare("UPDATE users SET password = NULL WHERE id = ?1")?
|
||||
.execute((self.id,))?;
|
||||
sqlx::query("UPDATE users SET password = NULL WHERE id = $1")
|
||||
.bind(self.id)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
Some(passw) => {
|
||||
let hashed = User::hash_password(passw)?;
|
||||
conn.prepare("UPDATE users SET password = ?1 WHERE id = ?2")?
|
||||
.execute((hashed, self.id))?;
|
||||
sqlx::query("UPDATE users SET password = $1 WHERE id = $2")
|
||||
.bind(hashed)
|
||||
.bind(self.id)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -148,14 +175,18 @@ impl User {
|
||||
/// to do everything and probably should not be used as a regular account
|
||||
/// due to the ramifications of compromise. But it could be used for that,
|
||||
/// and have its name changed.
|
||||
pub fn create_infradmin(conn: &Connection) -> Result<User, UserError> {
|
||||
pub async fn create_infradmin(conn: &mut PgConnection) -> Result<User, UserError> {
|
||||
let mut u = User {
|
||||
id: Uuid::max(),
|
||||
handle: UserHandle::new("Infradmin")?,
|
||||
};
|
||||
conn.prepare("INSERT INTO users(id, handle) VALUES (?1, ?2)")?
|
||||
.execute((&u.id, &u.handle))?;
|
||||
u.regenerate_infradmin_password(conn)?;
|
||||
sqlx::query("INSERT INTO users(id, handle) VALUES ($1, $2)")
|
||||
.bind(u.id)
|
||||
.bind(u.handle.as_str())
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
|
||||
u.regenerate_infradmin_password(conn).await?;
|
||||
|
||||
Ok(u)
|
||||
}
|
||||
@@ -178,9 +209,12 @@ impl User {
|
||||
/// to do everything and probably should not be used as a regular account
|
||||
/// due to the ramifications of compromise. But it could be used for that,
|
||||
/// and have its name changed.
|
||||
pub fn regenerate_infradmin_password(&mut self, conn: &Connection) -> Result<(), UserError> {
|
||||
pub async fn regenerate_infradmin_password(
|
||||
&mut self,
|
||||
conn: &mut PgConnection,
|
||||
) -> Result<(), UserError> {
|
||||
let passw = auth::generate_token(auth::TokenSize::Char16);
|
||||
self.set_password(conn, Some(&passw))?;
|
||||
self.set_password(conn, Some(&passw)).await?;
|
||||
log::info!("[USERS] The infradmin account password has been (re)generated.");
|
||||
log::info!("[USERS] Handle: {}", self.handle.as_str());
|
||||
log::info!("[USERS] Password: {}", passw);
|
||||
@@ -194,13 +228,16 @@ impl User {
|
||||
/// for actions performed by Mnemosyne internally.
|
||||
/// It shall not be available for log-in.
|
||||
/// It should not have its name changed, and should be protected from that.
|
||||
pub fn create_systemuser(conn: &Connection) -> Result<User, UserError> {
|
||||
pub async fn create_systemuser(conn: &mut PgConnection) -> Result<User, UserError> {
|
||||
let u = User {
|
||||
id: Uuid::nil(),
|
||||
handle: UserHandle::new("Mnemosyne")?,
|
||||
};
|
||||
conn.prepare("INSERT INTO users(id, handle) VALUES (?1, ?2)")?
|
||||
.execute((&u.id, &u.handle))?;
|
||||
sqlx::query("INSERT INTO users(id, handle) VALUES ($1, $2)")
|
||||
.bind(u.id)
|
||||
.bind(u.handle.as_str())
|
||||
.execute(conn)
|
||||
.await?;
|
||||
|
||||
Ok(u)
|
||||
}
|
||||
@@ -216,22 +253,24 @@ impl User {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<rusqlite::Error> for UserError {
|
||||
fn from(error: rusqlite::Error) -> Self {
|
||||
if let rusqlite::Error::SqliteFailure(err, Some(msg)) = &error
|
||||
&& err.extended_code == SQLITE_CONSTRAINT_UNIQUE
|
||||
&& msg.contains("handle")
|
||||
{
|
||||
return UserError::HandleAlreadyExists;
|
||||
impl From<sqlx::Error> for UserError {
|
||||
fn from(error: sqlx::Error) -> Self {
|
||||
if let sqlx::Error::Database(err) = &error {
|
||||
// Check for Postgres unique constraint violation (code 23505)
|
||||
if err.is_unique_violation() && err.message().contains("handle") {
|
||||
return UserError::HandleAlreadyExists;
|
||||
}
|
||||
}
|
||||
UserError::DatabaseError(DatabaseError::from(error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<argon2::password_hash::Error> for UserError {
|
||||
fn from(err: argon2::password_hash::Error) -> Self {
|
||||
UserError::PassHashError(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for UserError {
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use rusqlite::Connection;
|
||||
use sqlx::PgConnection;
|
||||
|
||||
use crate::{database::DatabaseError, users::User};
|
||||
|
||||
@@ -17,13 +17,14 @@ pub enum Permission {
|
||||
RenameTags,
|
||||
DeleteTags,
|
||||
ChangePersonPrimaryName,
|
||||
#[allow(unused)]
|
||||
BrowseServerLogs,
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn has_permission(
|
||||
pub async fn has_permission(
|
||||
&self,
|
||||
#[allow(unused)] conn: &Connection,
|
||||
#[allow(unused)] conn: &mut PgConnection,
|
||||
#[allow(unused)] permission: Permission,
|
||||
) -> Result<bool, DatabaseError> {
|
||||
// Infradmin and systemuser have all permissions
|
||||
|
||||
@@ -3,9 +3,9 @@ use axum::{
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use rusqlite::{Connection, OptionalExtension};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::{PgConnection, Row};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
@@ -46,11 +46,13 @@ pub enum SessionError {
|
||||
#[error("No session found with provided token")]
|
||||
NoSessionWithToken(String),
|
||||
}
|
||||
impl From<rusqlite::Error> for SessionError {
|
||||
fn from(error: rusqlite::Error) -> Self {
|
||||
|
||||
impl From<sqlx::Error> for SessionError {
|
||||
fn from(error: sqlx::Error) -> Self {
|
||||
SessionError::DatabaseError(DatabaseError::from(error))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for SessionError {
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
@@ -70,55 +72,89 @@ impl IntoResponse for SessionError {
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn get_by_id(conn: &Connection, id: Uuid) -> Result<Session, SessionError> {
|
||||
let res = conn
|
||||
.prepare("SELECT user_id, expiry, revoked, revoked_at, revoked_by FROM sessions WHERE id = ?1")?
|
||||
.query_one((&id,), |r| Ok(Session {
|
||||
id,
|
||||
user_id: r.get(0)?,
|
||||
expiry: r.get(1)?,
|
||||
status: match r.get::<_, bool>(2)? {
|
||||
false => SessionStatus::Active,
|
||||
true => {
|
||||
SessionStatus::Revoked { revoked_at: r.get(3)?, revoked_by: r.get(4)? }
|
||||
}
|
||||
}
|
||||
})).optional()?;
|
||||
pub async fn get_by_id(conn: &mut PgConnection, id: Uuid) -> Result<Session, SessionError> {
|
||||
let row = sqlx::query(
|
||||
"SELECT user_id, expiry, revoked, revoked_at, revoked_by FROM sessions WHERE id = $1",
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(conn)
|
||||
.await?;
|
||||
|
||||
match res {
|
||||
Some(s) => Ok(s),
|
||||
match row {
|
||||
Some(r) => {
|
||||
let revoked: bool = r.try_get("revoked")?;
|
||||
let status = if revoked {
|
||||
SessionStatus::Revoked {
|
||||
revoked_at: r.try_get("revoked_at")?,
|
||||
revoked_by: r.try_get("revoked_by")?,
|
||||
}
|
||||
} else {
|
||||
SessionStatus::Active
|
||||
};
|
||||
|
||||
Ok(Session {
|
||||
id,
|
||||
user_id: r.try_get("user_id")?,
|
||||
expiry: r.try_get("expiry")?,
|
||||
status,
|
||||
})
|
||||
}
|
||||
None => Err(SessionError::NoSessionWithId(id)),
|
||||
}
|
||||
}
|
||||
pub fn get_by_token(conn: &Connection, token: &str) -> Result<Session, SessionError> {
|
||||
let hashed = Sha256::digest(token.as_bytes()).to_vec();
|
||||
let res = conn
|
||||
.prepare("SELECT id, user_id, expiry, revoked, revoked_at, revoked_by FROM sessions WHERE token = ?1")?
|
||||
.query_one((hashed,), |r| Ok(Session {
|
||||
id: r.get(0)?,
|
||||
user_id: r.get(1)?,
|
||||
expiry: r.get(2)?,
|
||||
status: match r.get::<_, bool>(3)? {
|
||||
false => SessionStatus::Active,
|
||||
true => {
|
||||
SessionStatus::Revoked { revoked_at: r.get(4)?, revoked_by: r.get(5)? }
|
||||
}
|
||||
}
|
||||
})).optional()?;
|
||||
|
||||
match res {
|
||||
Some(s) => Ok(s),
|
||||
pub async fn get_by_token(
|
||||
conn: &mut PgConnection,
|
||||
token: &str,
|
||||
) -> Result<Session, SessionError> {
|
||||
let hashed = Sha256::digest(token.as_bytes()).to_vec();
|
||||
let row = sqlx::query(
|
||||
"SELECT id, user_id, expiry, revoked, revoked_at, revoked_by FROM sessions WHERE token = $1",
|
||||
)
|
||||
.bind(&hashed)
|
||||
.fetch_optional(conn)
|
||||
.await?;
|
||||
|
||||
match row {
|
||||
Some(r) => {
|
||||
let revoked: bool = r.try_get("revoked")?;
|
||||
let status = if revoked {
|
||||
SessionStatus::Revoked {
|
||||
revoked_at: r.try_get("revoked_at")?,
|
||||
revoked_by: r.try_get("revoked_by")?,
|
||||
}
|
||||
} else {
|
||||
SessionStatus::Active
|
||||
};
|
||||
|
||||
Ok(Session {
|
||||
id: r.try_get("id")?,
|
||||
user_id: r.try_get("user_id")?,
|
||||
expiry: r.try_get("expiry")?,
|
||||
status,
|
||||
})
|
||||
}
|
||||
None => Err(SessionError::NoSessionWithToken(token.to_string())),
|
||||
}
|
||||
}
|
||||
pub fn new_for_user(conn: &Connection, user: &User) -> Result<(Session, String), SessionError> {
|
||||
|
||||
pub async fn new_for_user(
|
||||
conn: &mut PgConnection,
|
||||
user: &User,
|
||||
) -> Result<(Session, String), SessionError> {
|
||||
let id = Uuid::now_v7();
|
||||
let token = auth::generate_token(auth::TokenSize::Char64);
|
||||
let hashed = Sha256::digest(token.as_bytes()).to_vec();
|
||||
let expiry = Utc::now() + Session::DEFAULT_PROLONGATION;
|
||||
|
||||
conn.prepare("INSERT INTO sessions(id, token, user_id, expiry) VALUES (?1, ?2, ?3, ?4)")?
|
||||
.execute((&id, &hashed, user.id, expiry))?;
|
||||
sqlx::query("INSERT INTO sessions(id, token, user_id, expiry) VALUES ($1, $2, $3, $4)")
|
||||
.bind(id)
|
||||
.bind(hashed)
|
||||
.bind(user.id)
|
||||
.bind(expiry)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
|
||||
let s = Session {
|
||||
id,
|
||||
user_id: user.id,
|
||||
@@ -130,7 +166,8 @@ impl Session {
|
||||
|
||||
pub const DEFAULT_PROLONGATION: Duration = Duration::days(14);
|
||||
const PROLONGATION_THRESHOLD: Duration = Duration::hours(2);
|
||||
pub fn prolong(&mut self, conn: &Connection) -> Result<(), SessionError> {
|
||||
|
||||
pub async fn prolong(&mut self, conn: &mut PgConnection) -> Result<(), SessionError> {
|
||||
if self.expiry - Session::DEFAULT_PROLONGATION + Session::PROLONGATION_THRESHOLD
|
||||
> Utc::now()
|
||||
{
|
||||
@@ -138,22 +175,37 @@ impl Session {
|
||||
}
|
||||
|
||||
let expiry = Utc::now() + Session::DEFAULT_PROLONGATION;
|
||||
conn.prepare("UPDATE sessions SET expiry = ?1 WHERE id = ?2")?
|
||||
.execute((&expiry, &self.id))?;
|
||||
sqlx::query("UPDATE sessions SET expiry = $1 WHERE id = $2")
|
||||
.bind(expiry)
|
||||
.bind(self.id)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
|
||||
self.expiry = expiry;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn revoke(&mut self, conn: &Connection, actor: Option<&User>) -> Result<(), SessionError> {
|
||||
pub async fn revoke(
|
||||
&mut self,
|
||||
conn: &mut PgConnection,
|
||||
actor: Option<&User>,
|
||||
) -> Result<(), SessionError> {
|
||||
let now = Utc::now();
|
||||
let id = actor.map(|u| u.id).unwrap_or(Uuid::nil());
|
||||
conn.prepare(
|
||||
"UPDATE sessions SET revoked = ?1, revoked_at = ?2, revoked_by = ?3 WHERE id = ?4",
|
||||
)?
|
||||
.execute((&true, &now, &id, &self.id))?;
|
||||
let actor_id = actor.map(|u| u.id).unwrap_or(Uuid::nil());
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE sessions SET revoked = $1, revoked_at = $2, revoked_by = $3 WHERE id = $4",
|
||||
)
|
||||
.bind(true)
|
||||
.bind(now)
|
||||
.bind(actor_id)
|
||||
.bind(self.id)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
|
||||
self.status = SessionStatus::Revoked {
|
||||
revoked_at: now,
|
||||
revoked_by: id,
|
||||
revoked_by: actor_id,
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
@@ -165,9 +217,11 @@ impl Session {
|
||||
let timestamp = self.id.get_timestamp().unwrap().to_unix();
|
||||
DateTime::from_timestamp_secs(timestamp.0 as i64).unwrap()
|
||||
}
|
||||
|
||||
pub fn is_expired_or_revoked(&self) -> bool {
|
||||
self.is_expired() || self.status.is_revoked()
|
||||
}
|
||||
|
||||
pub fn is_expired(&self) -> bool {
|
||||
self.expiry <= Utc::now()
|
||||
}
|
||||
|
||||
@@ -1,40 +1,37 @@
|
||||
use rusqlite::OptionalExtension;
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
database,
|
||||
logs::{LogAction, LogEntry},
|
||||
users::{User, UserError},
|
||||
};
|
||||
|
||||
pub fn initialise_reserved_users_if_needed() -> Result<(), UserError> {
|
||||
let mut conn = database::conn()?;
|
||||
let tx = conn.transaction()?;
|
||||
pub async fn initialise_reserved_users_if_needed(pool: &PgPool) -> Result<(), UserError> {
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
if tx
|
||||
.prepare("SELECT handle FROM users WHERE id = ?1")?
|
||||
.query_one((&Uuid::nil(),), |_| Ok(()))
|
||||
.optional()?
|
||||
.is_none()
|
||||
{
|
||||
let u = User::create_systemuser(&tx)?;
|
||||
LogEntry::new(&tx, u, LogAction::Initialize)?;
|
||||
let systemuser_exists = sqlx::query("SELECT handle FROM users WHERE id = $1")
|
||||
.bind(Uuid::nil())
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?
|
||||
.is_some();
|
||||
|
||||
if !systemuser_exists {
|
||||
let u = User::create_systemuser(&mut *tx).await?;
|
||||
LogEntry::new(&mut *tx, u, LogAction::Initialize).await?;
|
||||
}
|
||||
|
||||
if tx
|
||||
.prepare("SELECT handle FROM users WHERE id = ?1")?
|
||||
.query_one((&Uuid::max(),), |_| Ok(()))
|
||||
.optional()?
|
||||
.is_none()
|
||||
{
|
||||
User::create_infradmin(&tx)?;
|
||||
LogEntry::new(
|
||||
&tx,
|
||||
User::get_by_id(&tx, Uuid::nil())?,
|
||||
LogAction::RegenInfradmin,
|
||||
)?;
|
||||
let infradmin_exists = sqlx::query("SELECT handle FROM users WHERE id = $1")
|
||||
.bind(Uuid::max())
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?
|
||||
.is_some();
|
||||
|
||||
if !infradmin_exists {
|
||||
User::create_infradmin(&mut *tx).await?;
|
||||
let u = User::get_by_id(&mut *tx, Uuid::max()).await?;
|
||||
LogEntry::new(&mut *tx, u, LogAction::RegenInfradmin).await?;
|
||||
}
|
||||
|
||||
tx.commit()?;
|
||||
tx.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user