use std::sync::Arc; use anyhow::Error; use aes_gcm::{aead::Aead, AeadCore, Aes256Gcm, Key, KeyInit}; use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier}; use base64::{engine::general_purpose::STANDARD, Engine as _}; use giterated_models::{ model::{authenticated::UserAuthenticationToken, instance::Instance, user::User}, operation::instance::{AuthenticationTokenRequest, RegisterAccountRequest}, }; use rsa::{ pkcs8::{EncodePrivateKey, EncodePublicKey}, rand_core::OsRng, RsaPrivateKey, RsaPublicKey, }; use secrecy::ExposeSecret; use sqlx::PgPool; use tokio::sync::Mutex; use crate::authentication::AuthenticationTokenGranter; use super::{AuthBackend, SettingsBackend, UserBackend}; pub struct UserAuth { pub pg_pool: PgPool, pub this_instance: Instance, pub auth_granter: Arc>, pub settings_provider: Arc>, } impl UserAuth { pub fn new( pool: PgPool, this_instance: &Instance, granter: Arc>, settings_provider: Arc>, ) -> Self { Self { pg_pool: pool, this_instance: this_instance.clone(), auth_granter: granter, settings_provider, } } } #[async_trait::async_trait] impl UserBackend for UserAuth { async fn exists(&mut self, user: &User) -> Result { Ok(sqlx::query_as!( UserRow, r#"SELECT * FROM users WHERE username = $1"#, user.username ) .fetch_one(&self.pg_pool.clone()) .await .is_ok()) } } #[async_trait::async_trait] impl AuthBackend for UserAuth { async fn register( &mut self, request: RegisterAccountRequest, ) -> Result { const BITS: usize = 2048; let private_key = RsaPrivateKey::new(&mut OsRng, BITS).unwrap(); let public_key = RsaPublicKey::from(&private_key); let key = { let mut target: [u8; 32] = [0; 32]; let mut index = 0; let mut iterator = request.password.expose_secret().0.as_bytes().iter(); while index < 32 { if let Some(next) = iterator.next() { target[index] = *next; index += 1; } else { iterator = request.password.expose_secret().0.as_bytes().iter(); } } target }; let key: &Key = &key.into(); let cipher = Aes256Gcm::new(key); let nonce = Aes256Gcm::generate_nonce(&mut OsRng); let ciphertext = cipher .encrypt(&nonce, private_key.to_pkcs8_der().unwrap().as_bytes()) .unwrap(); let private_key_enc = format!("{}#{}", STANDARD.encode(nonce), STANDARD.encode(ciphertext)); let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); let password_hash = argon2 .hash_password(request.password.expose_secret().0.as_bytes(), &salt) .unwrap() .to_string(); let user = match sqlx::query_as!( UserRow, r#"INSERT INTO users VALUES ($1, $2, $3, $4, $5) returning *"#, request.username, "example.com", password_hash, public_key .to_public_key_pem(rsa::pkcs8::LineEnding::LF) .unwrap(), private_key_enc ) .fetch_one(&self.pg_pool) .await { Ok(user) => user, Err(err) => { error!("Failed inserting into the database! {:?}", err); return Err(err.into()); } }; let mut granter = self.auth_granter.lock().await; let token = granter .create_token_for( &User { username: user.username, instance: self.this_instance.clone(), }, &self.this_instance, ) .await; Ok(UserAuthenticationToken::from(token)) } async fn login( &mut self, source: &Instance, request: AuthenticationTokenRequest, ) -> Result { info!("fetching!"); let user = sqlx::query_as!( UserRow, r#"SELECT * FROM users WHERE username = $1"#, request.username ) .fetch_one(&self.pg_pool) .await?; let hash = PasswordHash::new(&user.password).unwrap(); if Argon2::default() .verify_password(request.password.expose_secret().0.as_bytes(), &hash) .is_err() { info!("invalid password"); return Err(Error::from(AuthenticationError::InvalidPassword)); } let mut granter = self.auth_granter.lock().await; let token = granter .create_token_for( &User { username: user.username, instance: self.this_instance.clone(), }, &source, ) .await; Ok(UserAuthenticationToken::from(token)) } } #[allow(unused)] #[derive(Debug, sqlx::FromRow)] struct UserRow { pub username: String, pub email: Option, pub password: String, pub public_key: String, pub enc_private_key: Vec, } #[derive(Debug, thiserror::Error)] pub enum AuthenticationError { #[error("invalid password")] InvalidPassword, }