use std::sync::Arc; use anyhow::Error; use aes_gcm::{aead::Aead, AeadCore, Aes256Gcm, Key, KeyInit}; use argon2::{password_hash::SaltString, Argon2, PasswordHasher}; use base64::{engine::general_purpose::STANDARD, Engine as _}; use futures_util::StreamExt; use giterated_models::{ messages::{ authentication::{ AuthenticationTokenRequest, AuthenticationTokenResponse, RegisterAccountRequest, RegisterAccountResponse, }, user::{ UserBioRequest, UserBioResponse, UserDisplayImageRequest, UserDisplayImageResponse, UserDisplayNameRequest, UserDisplayNameResponse, }, }, model::{instance::Instance, user::User}, }; use rsa::{ pkcs8::{EncodePrivateKey, EncodePublicKey}, rand_core::OsRng, RsaPrivateKey, RsaPublicKey, }; use serde_json::Value; use sqlx::{Either, PgPool}; use tokio::sync::Mutex; use crate::authentication::AuthenticationTokenGranter; use super::{AuthBackend, UserBackend}; pub struct UserAuth { pub pg_pool: PgPool, pub this_instance: Instance, pub auth_granter: Arc>, } impl UserAuth { pub fn new( pool: PgPool, this_instance: &Instance, granter: Arc>, ) -> Self { Self { pg_pool: pool, this_instance: this_instance.clone(), auth_granter: granter, } } } #[async_trait::async_trait] impl UserBackend for UserAuth { async fn display_name( &mut self, request: UserDisplayNameRequest, ) -> Result { let db_row = sqlx::query_as!( UserRow, r#"SELECT * FROM users WHERE username = $1"#, request.user.username ) .fetch_one(&self.pg_pool.clone()) .await .unwrap(); Ok(UserDisplayNameResponse { display_name: db_row.display_name, }) } async fn display_image( &mut self, request: UserDisplayImageRequest, ) -> Result { let db_row = sqlx::query_as!( UserRow, r#"SELECT * FROM users WHERE username = $1"#, request.user.username ) .fetch_one(&self.pg_pool.clone()) .await .unwrap(); Ok(UserDisplayImageResponse { image_url: db_row.image_url, }) } async fn bio(&mut self, request: UserBioRequest) -> Result { let db_row = sqlx::query_as!( UserRow, r#"SELECT * FROM users WHERE username = $1"#, request.user.username ) .fetch_one(&self.pg_pool.clone()) .await?; Ok(UserBioResponse { bio: db_row.bio }) } async fn exists(&mut self, user: &User) -> Result { Ok(sqlx::query_as!( UserRow, r#"SELECT * FROM users WHERE username = $1"#, user.username ) .fetch_one(&self.pg_pool.clone()) .await .is_err()) } async fn settings(&mut self, user: &User) -> Result, Error> { let settings = sqlx::query_as!( UserSettingRow, r#"SELECT * FROM user_settings WHERE username = $1"#, user.username ) .fetch_many(&self.pg_pool) .filter_map(|result| async move { if let Ok(Either::Right(row)) = result { Some(row) } else { None } }) .filter_map(|row| async move { if let Ok(value) = serde_json::from_str(&row.value) { Some((row.name, value)) } else { None } }) .collect::>() .await; Ok(settings) } async fn write_settings( &mut self, user: &User, settings: &[(String, Value)], ) -> Result<(), Error> { for (name, value) in settings { let serialized = serde_json::to_string(value)?; sqlx::query!("INSERT INTO user_settings VALUES ($1, $2, $3) ON CONFLICT (username, name) DO UPDATE SET value = $3", user.username, name, serialized) .execute(&self.pg_pool).await?; } Ok(()) } } #[async_trait::async_trait] impl AuthBackend for UserAuth { async fn register( &mut self, request: RegisterAccountRequest, ) -> Result { const BITS: usize = 2048; let private_key = RsaPrivateKey::new(&mut OsRng, BITS).unwrap(); let public_key = RsaPublicKey::from(&private_key); let key = { let mut target: [u8; 32] = [0; 32]; let mut index = 0; let mut iterator = request.password.as_bytes().iter(); while index < 32 { if let Some(next) = iterator.next() { target[index] = *next; index += 1; } else { iterator = request.password.as_bytes().iter(); } } target }; let key: &Key = &key.into(); let cipher = Aes256Gcm::new(key); let nonce = Aes256Gcm::generate_nonce(&mut OsRng); let ciphertext = cipher .encrypt(&nonce, private_key.to_pkcs8_der().unwrap().as_bytes()) .unwrap(); let private_key_enc = format!("{}#{}", STANDARD.encode(nonce), STANDARD.encode(ciphertext)); let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); let password_hash = argon2 .hash_password(request.password.as_bytes(), &salt) .unwrap() .to_string(); let user = match sqlx::query_as!( UserRow, r#"INSERT INTO users VALUES ($1, null, $2, null, null, $3, $4, $5) returning *"#, request.username, "example.com", password_hash, public_key .to_public_key_pem(rsa::pkcs8::LineEnding::LF) .unwrap(), private_key_enc ) .fetch_one(&self.pg_pool) .await { Ok(user) => user, Err(err) => { error!("Failed inserting into the database! {:?}", err); return Err(err.into()); } }; let mut granter = self.auth_granter.lock().await; let token = granter .create_token_for( &User { username: user.username, instance: self.this_instance.clone(), }, &self.this_instance, ) .await; Ok(RegisterAccountResponse { token }) } async fn login( &mut self, _request: AuthenticationTokenRequest, ) -> Result { todo!() } } #[allow(unused)] #[derive(Debug, sqlx::FromRow)] struct UserRow { pub username: String, pub image_url: Option, pub display_name: Option, pub bio: Option, pub email: Option, pub password: String, pub public_key: String, pub enc_private_key: Vec, } #[allow(unused)] #[derive(Debug, sqlx::FromRow)] struct UserSettingRow { pub username: String, pub name: String, pub value: String, }