diff --git a/Cargo.lock b/Cargo.lock index db3e428..1107fa0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -652,6 +652,7 @@ dependencies = [ "tokio", "tokio-tungstenite", "toml", + "tower", "tracing", "tracing-subscriber", ] @@ -2230,6 +2231,23 @@ dependencies = [ ] [[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + +[[package]] name = "tower-service" version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/Cargo.toml b/Cargo.toml index 0024ba0..838bcb5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ tokio-tungstenite = "*" tokio = { version = "1.32.0", features = [ "full" ] } tracing = "*" futures-util = "*" -serde = { version = "1", features = [ "derive" ]} +serde = { version = "1.0.188", features = [ "derive" ]} serde_json = "1.0" tracing-subscriber = "0.3" base64 = "0.21.3" @@ -22,6 +22,7 @@ reqwest = "*" argon2 = "*" aes-gcm = "0.10.2" semver = {version = "*", features = ["serde"]} +tower = "*" toml = { version = "0.7" } diff --git a/src/authentication.rs b/src/authentication.rs index b65ff15..55be9ff 100644 --- a/src/authentication.rs +++ b/src/authentication.rs @@ -13,7 +13,7 @@ use crate::{ }, InstanceAuthenticated, }, - model::{instance::Instance, user::User}, + model::{authenticated::UserAuthenticationToken, instance::Instance, user::User}, }; #[derive(Debug, Serialize, Deserialize)] @@ -74,21 +74,10 @@ impl AuthenticationTokenGranter { pub async fn token_request( &mut self, - raw_request: InstanceAuthenticated, + issued_for: impl ToOwned, + username: String, + password: String, ) -> Result { - let request = raw_request.inner().await; - - info!("Ensuring token request is from the same instance..."); - raw_request - .validate(&Instance { - url: String::from("giterated.dev"), - }) - .await - .unwrap(); - - let secret_key = self.config["authentication"]["secret_key"] - .as_str() - .unwrap(); let private_key = { let mut file = File::open( self.config["giterated"]["keys"]["private"] @@ -104,20 +93,14 @@ impl AuthenticationTokenGranter { key }; - if request.secret_key != secret_key { - error!("Incorrect secret key!"); - - panic!() - } - let encoding_key = EncodingKey::from_rsa_pem(&private_key).unwrap(); let claims = UserTokenMetadata { user: User { - username: request.username.clone(), + username, instance: self.instance.clone(), }, - generated_for: raw_request.instance.clone(), + generated_for: issued_for.to_owned(), exp: (SystemTime::UNIX_EPOCH.elapsed().unwrap() + std::time::Duration::from_secs(24 * 60 * 60)) .as_secs(), @@ -135,53 +118,25 @@ impl AuthenticationTokenGranter { pub async fn extension_request( &mut self, - raw_request: InstanceAuthenticated, + issued_for: &Instance, + token: UserAuthenticationToken, ) -> Result { - let request = raw_request.inner().await; - - // let server_public_key = { - // let mut file = File::open(self.config["keys"]["public"].as_str().unwrap()) - // .await - // .unwrap(); - - // let mut key = String::default(); - // file.read_to_string(&mut key).await.unwrap(); - - // key - // }; - - let server_public_key = public_key(&Instance { - url: String::from("giterated.dev"), - }) - .await - .unwrap(); + let server_public_key = public_key(&self.instance).await.unwrap(); let verification_key = DecodingKey::from_rsa_pem(server_public_key.as_bytes()).unwrap(); let data: TokenData = decode( - &request.token, + token.as_ref(), &verification_key, &Validation::new(Algorithm::RS256), ) .unwrap(); - info!("Token Extension Request Token validated"); - - let secret_key = self.config["authentication"]["secret_key"] - .as_str() - .unwrap(); - - if request.secret_key != secret_key { - error!("Incorrect secret key!"); - + if data.claims.generated_for != *issued_for { panic!() } - // Validate request - raw_request - .validate(&data.claims.generated_for) - .await - .unwrap(); - info!("Validated request for key extension"); + + info!("Token Extension Request Token validated"); let private_key = { let mut file = File::open( @@ -203,7 +158,7 @@ impl AuthenticationTokenGranter { let claims = UserTokenMetadata { // TODO: Probably exploitable user: data.claims.user, - generated_for: data.claims.generated_for, + generated_for: issued_for.clone(), exp: (SystemTime::UNIX_EPOCH.elapsed().unwrap() + std::time::Duration::from_secs(24 * 60 * 60)) .as_secs(), diff --git a/src/backend/git.rs b/src/backend/git.rs index 06a7559..892b8f4 100644 --- a/src/backend/git.rs +++ b/src/backend/git.rs @@ -212,26 +212,9 @@ impl GitBackend { impl RepositoryBackend for GitBackend { async fn create_repository( &mut self, - raw_request: &ValidatedUserAuthenticated, + user: &User, + request: &CreateRepositoryRequest, ) -> Result { - let request = raw_request.inner().await; - - // let public_key = public_key(&Instance { - // url: String::from("giterated.dev"), - // }) - // .await - // .unwrap(); - // - // match raw_request.validate(public_key).await { - // Ok(_) => info!("Request was validated"), - // Err(err) => { - // error!("Failed to validate request: {:?}", err); - // panic!(); - // } - // } - // - // info!("Request was valid!"); - // Check if repository already exists in the database if let Ok(repository) = self .find_by_owner_user_name(&request.owner, &request.name) @@ -297,11 +280,9 @@ impl RepositoryBackend for GitBackend { async fn repository_info( &mut self, - // TODO: Allow non-authenticated??? - raw_request: &ValidatedUserAuthenticated, + requester: Option<&User>, + request: &RepositoryInfoRequest, ) -> Result { - let request = raw_request.inner().await; - let repository = match self .find_by_owner_user_name( // &request.owner.instance.url, @@ -314,7 +295,17 @@ impl RepositoryBackend for GitBackend { Err(err) => return Err(Box::new(err).into()), }; - if !repository.can_user_view_repository(Some(&raw_request.user)) { + if let Some(requester) = requester { + if !repository.can_user_view_repository(Some(&requester)) { + return Err(Box::new(GitBackendError::RepositoryNotFound { + owner_user: request.repository.owner.to_string(), + name: request.repository.name.clone(), + }) + .into()); + } + } else if matches!(repository.visibility, RepositoryVisibility::Private) { + // Unauthenticated users can never view private repositories + return Err(Box::new(GitBackendError::RepositoryNotFound { owner_user: request.repository.owner.to_string(), name: request.repository.name.clone(), @@ -446,9 +437,10 @@ impl RepositoryBackend for GitBackend { }) } - fn repository_file_inspect( + async fn repository_file_inspect( &mut self, - _request: &ValidatedUserAuthenticated, + requester: Option<&User>, + _request: &RepositoryFileInspectRequest, ) -> Result { todo!() } diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 8624fdf..a861980 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -35,15 +35,18 @@ use crate::{ pub trait RepositoryBackend: IssuesBackend { async fn create_repository( &mut self, - request: &ValidatedUserAuthenticated, + user: &User, + request: &CreateRepositoryRequest, ) -> Result; async fn repository_info( &mut self, - request: &ValidatedUserAuthenticated, + requester: Option<&User>, + request: &RepositoryInfoRequest, ) -> Result; - fn repository_file_inspect( + async fn repository_file_inspect( &mut self, - request: &ValidatedUserAuthenticated, + requester: Option<&User>, + request: &RepositoryFileInspectRequest, ) -> Result; async fn repositories_for_user(&mut self, user: &User) -> Result, Error>; @@ -90,6 +93,7 @@ pub trait UserBackend: AuthBackend { ) -> Result; async fn bio(&mut self, request: UserBioRequest) -> Result; + async fn exists(&mut self, user: &User) -> Result; } #[async_trait::async_trait] diff --git a/src/backend/user.rs b/src/backend/user.rs index 912c7b4..cff3644 100644 --- a/src/backend/user.rs +++ b/src/backend/user.rs @@ -100,6 +100,17 @@ impl UserBackend for UserAuth { Ok(UserBioResponse { bio: db_row.bio }) } + + async fn exists(&mut self, user: &User) -> Result { + Ok(sqlx::query_as!( + UserRow, + r#"SELECT * FROM users WHERE username = $1"#, + user.username + ) + .fetch_one(&self.pg_pool.clone()) + .await + .is_err()) + } } #[async_trait::async_trait] diff --git a/src/connection.rs b/src/connection.rs index 28ead95..4c39159 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,9 +1,15 @@ -use std::{collections::HashMap, net::SocketAddr, str::FromStr, sync::Arc}; +pub mod authentication; +pub mod handshake; +pub mod repository; +pub mod user; +pub mod wrapper; + +use std::{any::type_name, collections::HashMap, net::SocketAddr, str::FromStr, sync::Arc}; use anyhow::Error; use futures_util::{stream::StreamExt, SinkExt}; use semver::Version; -use serde::Serialize; +use serde::{de::DeserializeOwned, Serialize}; use tokio::{ net::TcpStream, sync::{ @@ -31,7 +37,7 @@ use crate::{ UserMessage, UserMessageKind, UserMessageRequest, UserMessageResponse, UserRepositoriesResponse, }, - MessageKind, + ErrorMessage, MessageKind, }, model::{ instance::{Instance, InstanceMeta}, @@ -41,6 +47,16 @@ use crate::{ validate_version, version, }; +#[derive(Debug, thiserror::Error)] +pub enum ConnectionError { + #[error("connection error message {0}")] + ErrorMessage(#[from] ErrorMessage), + #[error("connection should close")] + Shutdown, + #[error("internal error {0}")] + InternalError(#[from] Error), +} + pub struct RawConnection { pub task: JoinHandle<()>, } @@ -63,523 +79,153 @@ pub struct Connections { } pub async fn connection_worker( - mut socket: WebSocketStream, - listeners: Arc>, - connections: Arc>, - backend: Arc>, - user_backend: Arc>, - auth_granter: Arc>, - discovery_backend: Arc>, - addr: SocketAddr, -) { - let mut handshaked = false; + mut socket: &mut WebSocketStream, + handshaked: &mut bool, + listeners: &Arc>, + connections: &Arc>, + backend: &Arc>, + user_backend: &Arc>, + auth_granter: &Arc>, + discovery_backend: &Arc>, + addr: &SocketAddr, +) -> Result<(), ConnectionError> { let this_instance = Instance { url: String::from("giterated.dev"), }; - while let Some(message) = socket.next().await { - let message = match message { - Ok(message) => message, - Err(err) => { - error!("Error reading message: {:?}", err); - continue; - } - }; + let message = socket + .next() + .await + .ok_or_else(|| ConnectionError::Shutdown)? + .map_err(|e| Error::from(e))?; + + let payload = match message { + Message::Text(text) => text.into_bytes(), + Message::Binary(bytes) => bytes, + Message::Ping(_) => return Ok(()), + Message::Pong(_) => return Ok(()), + Message::Close(_) => { + info!("Closing connection with {}.", addr); + + return Err(ConnectionError::Shutdown); + } + _ => unreachable!(), + }; - let payload = match message { - Message::Text(text) => text.into_bytes(), - Message::Binary(bytes) => bytes, - Message::Ping(_) => continue, - Message::Pong(_) => continue, - Message::Close(_) => { - info!("Closing connection with {}.", addr); + let message = serde_json::from_slice::(&payload).map_err(|e| Error::from(e))?; - return; + if let MessageKind::Handshake(handshake) = message { + match handshake { + HandshakeMessage::Initiate(request) => { + unimplemented!() } - _ => unreachable!(), - }; - - let message = match serde_json::from_slice::(&payload) { - Ok(message) => message, - Err(err) => { - error!("Error deserializing message from {}: {:?}", addr, err); - continue; + HandshakeMessage::Response(response) => { + unimplemented!() } - }; - - // info!("Read payload: {}", std::str::from_utf8(&payload).unwrap()); - - if let MessageKind::Handshake(handshake) = message { - match handshake { - HandshakeMessage::Initiate(request) => { - // Send HandshakeMessage::Response - let message = HandshakeResponse { - identity: this_instance.clone(), - version: version(), - }; - - let version_check = validate_version(&request.version); - - let _result = if !version_check { - error!( - "Version compatibility failure! Our Version: {}, Their Version: {}", - Version::from_str(&std::env::var("CARGO_PKG_VERSION").unwrap()) - .unwrap(), - request.version - ); - - send( - &mut socket, - MessageKind::Handshake(HandshakeMessage::Finalize(HandshakeFinalize { - success: false, - })), - ) - .await - } else { - send( - &mut socket, - MessageKind::Handshake(HandshakeMessage::Response(message)), - ) - .await - }; - - continue; - } - HandshakeMessage::Response(response) => { - // Check version - let message = if validate_version(&response.version) { - error!( - "Version compatibility failure! Our Version: {}, Their Version: {}", - version(), - response.version - ); - - HandshakeFinalize { success: false } - } else { - info!("Connected with a compatible version"); - - HandshakeFinalize { success: true } - }; - - let _result = send( - &mut socket, - MessageKind::Handshake(HandshakeMessage::Finalize(message)), - ) - .await; - - continue; - } - HandshakeMessage::Finalize(response) => { - if !response.success { - error!("Error during handshake, aborting connection"); - return; - } - - handshaked = true; - - // Send HandshakeMessage::Finalize - let message = HandshakeFinalize { success: true }; - - let _result = send( - &mut socket, - MessageKind::Handshake(HandshakeMessage::Finalize(message)), - ) - .await; - - continue; - } + HandshakeMessage::Finalize(response) => { + unimplemented!() } } + } - if !handshaked { - continue; - } + if !*handshaked { + return Ok(()); + } - if let MessageKind::Repository(repository) = &message { - if repository.target.instance != this_instance { - info!("Forwarding command to {}", repository.target.instance.url); - // We need to send this command to a different instance - - let mut listener = send_and_get_listener(message, &listeners, &connections).await; - - // Wait for response - while let Ok(message) = listener.recv().await { - if let MessageKind::Repository(RepositoryMessage { - command: RepositoryMessageKind::Response(_), - .. - }) = message - { - let _result = send(&mut socket, message).await; - } - } - continue; - } else { - // This message is targeting this instance - match &repository.command { - RepositoryMessageKind::Request(request) => match request.clone() { - RepositoryRequest::CreateRepository(request) => { - let mut backend = backend.lock().await; - let request = request.validate().await.unwrap(); - let response = backend.create_repository(&request).await; - - let response = match response { - Ok(response) => response, - Err(err) => { - error!("Error handling request: {:?}", err); - continue; - } - }; - drop(backend); - - let _result = send( - &mut socket, - MessageKind::Repository(RepositoryMessage { - target: repository.target.clone(), - command: RepositoryMessageKind::Response( - RepositoryResponse::CreateRepository(response), - ), - }), - ) - .await; - - continue; - } - RepositoryRequest::RepositoryFileInspect(request) => { - let mut backend = backend.lock().await; - let request = request.validate().await.unwrap(); - let response = backend.repository_file_inspect(&request); - - let response = match response { - Ok(response) => response, - Err(err) => { - error!("Error handling request: {:?}", err); - continue; - } - }; - drop(backend); - - let _result = send( - &mut socket, - MessageKind::Repository(RepositoryMessage { - target: repository.target.clone(), - command: RepositoryMessageKind::Response( - RepositoryResponse::RepositoryFileInspection(response), - ), - }), - ) - .await; - - continue; - } - RepositoryRequest::RepositoryInfo(request) => { - let mut backend = backend.lock().await; - let request = request.validate().await.unwrap(); - let response = backend.repository_info(&request).await; - - let response = match response { - Ok(response) => response, - Err(err) => { - error!("Error handling request: {:?}", err); - continue; - } - }; - drop(backend); - - let _result = send( - &mut socket, - MessageKind::Repository(RepositoryMessage { - target: repository.target.clone(), - command: RepositoryMessageKind::Response( - RepositoryResponse::RepositoryInfo(response), - ), - }), - ) - .await; - - continue; - } - RepositoryRequest::IssuesCount(request) => { - let request = &request.validate().await.unwrap(); - - let mut backend = backend.lock().await; - let response = backend.issues_count(request); - - let response = match response { - Ok(response) => response, - Err(err) => { - error!("Error handling request: {:?}", err); - continue; - } - }; - drop(backend); - - let _result = send( - &mut socket, - MessageKind::Repository(RepositoryMessage { - target: repository.target.clone(), - command: RepositoryMessageKind::Response( - RepositoryResponse::IssuesCount(response), - ), - }), - ) - .await; - - continue; - } - RepositoryRequest::IssueLabels(request) => { - let request = &request.validate().await.unwrap(); - - let mut backend = backend.lock().await; - let response = backend.issue_labels(request); - - let response = match response { - Ok(response) => response, - Err(err) => { - error!("Error handling request: {:?}", err); - continue; - } - }; - drop(backend); - - let _result = send( - &mut socket, - MessageKind::Repository(RepositoryMessage { - target: repository.target.clone(), - command: RepositoryMessageKind::Response( - RepositoryResponse::IssueLabels(response), - ), - }), - ) - .await; - - continue; - } - RepositoryRequest::Issues(request) => { - let request = request.validate().await.unwrap(); - - let mut backend = backend.lock().await; - let response = backend.issues(&request); - - let response = match response { - Ok(response) => response, - Err(err) => { - error!("Error handling request: {:?}", err); - continue; - } - }; - drop(backend); - - let _result = send( - &mut socket, - MessageKind::Repository(RepositoryMessage { - target: repository.target.clone(), - command: RepositoryMessageKind::Response( - RepositoryResponse::Issues(response), - ), - }), - ) - .await; - - continue; - } - }, - RepositoryMessageKind::Response(_response) => { - unreachable!() - } + if let MessageKind::Repository(repository) = &message { + if repository.target.instance != this_instance { + info!("Forwarding command to {}", repository.target.instance.url); + // We need to send this command to a different instance + + let mut listener = send_and_get_listener(message, &listeners, &connections).await; + + // Wait for response + while let Ok(message) = listener.recv().await { + if let MessageKind::Repository(RepositoryMessage { + command: RepositoryMessageKind::Response(_), + .. + }) = message + { + let _result = send(&mut socket, message).await; } } - } - if let MessageKind::Authentication(authentication) = &message { - match authentication { - AuthenticationMessage::Request(request) => match request { - AuthenticationRequest::AuthenticationToken(token) => { - let mut granter = auth_granter.lock().await; - - let response = granter.token_request(token.clone()).await.unwrap(); - drop(granter); - - let _result = send( - &mut socket, - MessageKind::Authentication(AuthenticationMessage::Response( - AuthenticationResponse::AuthenticationToken(response), - )), - ) - .await; - - continue; + return Ok(()); + } else { + // This message is targeting this instance + match &repository.command { + RepositoryMessageKind::Request(request) => match request.clone() { + RepositoryRequest::CreateRepository(request) => { + unimplemented!(); } - AuthenticationRequest::TokenExtension(request) => { - let mut granter = auth_granter.lock().await; - - let response = granter - .extension_request(request.clone()) - .await - .unwrap_or(TokenExtensionResponse { new_token: None }); - drop(granter); - - let _result = send( - &mut socket, - MessageKind::Authentication(AuthenticationMessage::Response( - AuthenticationResponse::TokenExtension(response), - )), - ) - .await; - - continue; + RepositoryRequest::RepositoryFileInspect(request) => { + unimplemented!() } - AuthenticationRequest::RegisterAccount(request) => { - let request = request.inner().await.clone(); - - let mut user_backend = user_backend.lock().await; - - let response = user_backend.register(request.clone()).await.unwrap(); - drop(user_backend); - - let _result = send( - &mut socket, - MessageKind::Authentication(AuthenticationMessage::Response( - AuthenticationResponse::RegisterAccount(response), - )), - ) - .await; - - continue; + RepositoryRequest::RepositoryInfo(request) => { + unimplemented!() + } + RepositoryRequest::IssuesCount(request) => { + unimplemented!() + } + RepositoryRequest::IssueLabels(request) => { + unimplemented!() + } + RepositoryRequest::Issues(request) => { + unimplemented!(); } }, - AuthenticationMessage::Response(_) => unreachable!(), + RepositoryMessageKind::Response(_response) => { + unreachable!() + } } } + } - if let MessageKind::Discovery(message) = &message { - let mut backend = discovery_backend.lock().await; - backend.try_handle(message).await.unwrap(); - - continue; + if let MessageKind::Authentication(authentication) = &message { + match authentication { + AuthenticationMessage::Request(request) => match request { + AuthenticationRequest::AuthenticationToken(token) => { + unimplemented!() + } + AuthenticationRequest::TokenExtension(request) => { + unimplemented!() + } + AuthenticationRequest::RegisterAccount(request) => { + unimplemented!() + } + }, + AuthenticationMessage::Response(_) => unreachable!(), } + } - if let MessageKind::User(message) = &message { - match &message.message { - UserMessageKind::Request(request) => match request { - UserMessageRequest::DisplayName(request) => { - let mut user_backend = user_backend.lock().await; - - let response = user_backend.display_name(request.clone()).await; - - let response = match response { - Ok(response) => response, - Err(err) => { - error!("Error handling request: {:?}", err); - continue; - } - }; - drop(user_backend); - - let _result = send( - &mut socket, - MessageKind::User(UserMessage { - instance: message.instance.clone(), - message: UserMessageKind::Response( - UserMessageResponse::DisplayName(response), - ), - }), - ) - .await; - - continue; - } - UserMessageRequest::DisplayImage(request) => { - let mut user_backend = user_backend.lock().await; - - let response = user_backend.display_image(request.clone()).await; - - let response = match response { - Ok(response) => response, - Err(err) => { - error!("Error handling request: {:?}", err); - continue; - } - }; - drop(user_backend); - - let _result = send( - &mut socket, - MessageKind::User(UserMessage { - instance: message.instance.clone(), - message: UserMessageKind::Response( - UserMessageResponse::DisplayImage(response), - ), - }), - ) - .await; - - continue; - } - UserMessageRequest::Bio(request) => { - let mut user_backend = user_backend.lock().await; - - let response = user_backend.bio(request.clone()).await; - - let response = match response { - Ok(response) => response, - Err(err) => { - error!("Error handling request: {:?}", err); - continue; - } - }; - drop(user_backend); - - let _result = send( - &mut socket, - MessageKind::User(UserMessage { - instance: message.instance.clone(), - message: UserMessageKind::Response(UserMessageResponse::Bio( - response, - )), - }), - ) - .await; - - continue; - } - UserMessageRequest::Repositories(request) => { - let mut repository_backend = backend.lock().await; - - let repositories = repository_backend - .repositories_for_user(&request.user) - .await; - - let repositories = match repositories { - Ok(repositories) => repositories, - Err(err) => { - error!("Error handling request: {:?}", err); - continue; - } - }; - drop(repository_backend); - - let response = UserRepositoriesResponse { repositories }; - - let _result = send( - &mut socket, - MessageKind::User(UserMessage { - instance: message.instance.clone(), - message: UserMessageKind::Response( - UserMessageResponse::Repositories(response), - ), - }), - ) - .await; - - continue; - } - }, - UserMessageKind::Response(_) => unreachable!(), - } + if let MessageKind::Discovery(message) = &message { + let mut backend = discovery_backend.lock().await; + backend.try_handle(message).await?; + + return Ok(()); + } + + if let MessageKind::User(message) = &message { + match &message.message { + UserMessageKind::Request(request) => match request { + UserMessageRequest::DisplayName(request) => { + unimplemented!() + } + UserMessageRequest::DisplayImage(request) => { + unimplemented!() + } + UserMessageRequest::Bio(request) => { + unimplemented!() + } + UserMessageRequest::Repositories(request) => { + unimplemented!() + } + }, + UserMessageKind::Response(_) => unreachable!(), } } - info!("Connection closed"); + Ok(()) } async fn send_and_get_listener( @@ -596,6 +242,7 @@ async fn send_and_get_listener( MessageKind::Authentication(_) => todo!(), MessageKind::Discovery(_) => todo!(), MessageKind::User(user) => todo!(), + MessageKind::Error(_) => todo!(), }; let target = match (&instance, &user, &repository) { @@ -631,8 +278,56 @@ async fn send( message: T, ) -> Result<(), Error> { socket - .send(Message::Binary(serde_json::to_vec(&message).unwrap())) + .send(Message::Binary(serde_json::to_vec(&message)?)) .await?; Ok(()) } + +#[derive(Debug, thiserror::Error)] +#[error("handler did not handle")] +pub struct HandlerUnhandled; + +pub trait MessageHandling { + fn message_type() -> &'static str; +} + +impl MessageHandling<(T1,), M, R> for F +where + F: FnOnce(T1) -> R, + T1: Serialize + DeserializeOwned, +{ + fn message_type() -> &'static str { + type_name::() + } +} + +impl MessageHandling<(T1, T2), M, R> for F +where + F: FnOnce(T1, T2) -> R, + T1: Serialize + DeserializeOwned, +{ + fn message_type() -> &'static str { + type_name::() + } +} + +impl MessageHandling<(T1, T2, T3), M, R> for F +where + F: FnOnce(T1, T2, T3) -> R, + T1: Serialize + DeserializeOwned, +{ + fn message_type() -> &'static str { + type_name::() + } +} + +impl MessageHandling<(T1, T2, T3, T4), M, R> for F +where + F: FnOnce(T1, T2, T3, T4) -> R, + T1: Serialize + DeserializeOwned, +{ + fn message_type() -> &'static str { + type_name::() + } +} diff --git a/src/connection/authentication.rs b/src/connection/authentication.rs new file mode 100644 index 0000000..b773f5e --- /dev/null +++ b/src/connection/authentication.rs @@ -0,0 +1,160 @@ +use anyhow::Error; +use thiserror::Error; + +use crate::messages::authentication::{AuthenticationMessage, AuthenticationResponse}; +use crate::model::authenticated::MessageHandler; +use crate::{ + messages::{authentication::AuthenticationRequest, MessageKind}, + model::authenticated::{AuthenticatedInstance, NetworkMessage, State}, +}; + +use super::wrapper::ConnectionState; +use super::HandlerUnhandled; + +pub async fn authentication_handle( + message: &NetworkMessage, + state: &ConnectionState, +) -> Result<(), Error> { + let message_kind: MessageKind = serde_json::from_slice(&message.0).unwrap(); + + match message_kind { + MessageKind::Authentication(AuthenticationMessage::Request(request)) => match request { + AuthenticationRequest::RegisterAccount(_) => { + register_account_request + .handle_message(message, state) + .await + } + AuthenticationRequest::AuthenticationToken(_) => { + authentication_token_request + .handle_message(message, state) + .await + } + AuthenticationRequest::TokenExtension(_) => { + token_extension_request.handle_message(message, state).await + } + }, + _ => Err(Error::from(HandlerUnhandled)), + } +} + +async fn register_account_request( + State(connection_state): State, + request: MessageKind, + instance: AuthenticatedInstance, +) -> Result<(), AuthenticationConnectionError> { + let request = if let MessageKind::Authentication(AuthenticationMessage::Request( + AuthenticationRequest::RegisterAccount(request), + )) = request + { + request + } else { + return Err(AuthenticationConnectionError::InvalidRequest); + }; + + if *instance.inner() != connection_state.instance { + return Err(AuthenticationConnectionError::SameInstance); + } + + let mut user_backend = connection_state.user_backend.lock().await; + + let response = user_backend + .register(request.clone()) + .await + .map_err(|e| AuthenticationConnectionError::Registration(e))?; + drop(user_backend); + + connection_state + .send(MessageKind::Authentication( + AuthenticationMessage::Response(AuthenticationResponse::RegisterAccount(response)), + )) + .await + .map_err(|e| AuthenticationConnectionError::Sending(e))?; + + Ok(()) +} + +async fn authentication_token_request( + State(connection_state): State, + request: MessageKind, + instance: AuthenticatedInstance, +) -> Result<(), AuthenticationConnectionError> { + let request = if let MessageKind::Authentication(AuthenticationMessage::Request( + AuthenticationRequest::AuthenticationToken(request), + )) = request + { + request + } else { + return Err(AuthenticationConnectionError::InvalidRequest); + }; + + let issued_for = instance.inner().clone(); + + let mut token_granter = connection_state.auth_granter.lock().await; + + let response = token_granter + .token_request(issued_for, request.username, request.password) + .await + .map_err(|e| AuthenticationConnectionError::TokenIssuance(e))?; + + connection_state + .send(MessageKind::Authentication( + AuthenticationMessage::Response(AuthenticationResponse::AuthenticationToken(response)), + )) + .await + .map_err(|e| AuthenticationConnectionError::Sending(e))?; + + Ok(()) +} + +async fn token_extension_request( + State(connection_state): State, + request: MessageKind, + instance: AuthenticatedInstance, +) -> Result<(), AuthenticationConnectionError> { + let request = if let MessageKind::Authentication(AuthenticationMessage::Request( + AuthenticationRequest::TokenExtension(request), + )) = request + { + request + } else { + return Err(AuthenticationConnectionError::InvalidRequest); + }; + + let issued_for = instance.inner().clone(); + + let mut token_granter = connection_state.auth_granter.lock().await; + + let response = token_granter + .extension_request(&issued_for, request.token) + .await + .map_err(|e| AuthenticationConnectionError::TokenIssuance(e))?; + + connection_state + .send(MessageKind::Authentication( + AuthenticationMessage::Response(AuthenticationResponse::TokenExtension(response)), + )) + .await + .map_err(|e| AuthenticationConnectionError::Sending(e))?; + + Ok(()) +} + +async fn verify(state: ConnectionState) { + register_account_request + .handle_message(&NetworkMessage(vec![]), &state) + .await; +} + +#[derive(Debug, Error)] +pub enum AuthenticationConnectionError { + #[error("the request was invalid")] + InvalidRequest, + #[error("request must be from the same instance")] + SameInstance, + #[error("issue during registration {0}")] + Registration(Error), + #[error("sending error")] + Sending(Error), + #[error("error issuing token")] + TokenIssuance(Error), +} diff --git a/src/connection/handshake.rs b/src/connection/handshake.rs new file mode 100644 index 0000000..bf77de9 --- /dev/null +++ b/src/connection/handshake.rs @@ -0,0 +1,128 @@ +use std::{str::FromStr, sync::atomic::Ordering}; + +use anyhow::Error; +use semver::Version; +use thiserror::Error; + +use crate::model::authenticated::MessageHandler; +use crate::{ + connection::ConnectionError, + handshake::{HandshakeFinalize, HandshakeResponse, InitiateHandshake}, + model::authenticated::{AuthenticatedInstance, Message, NetworkMessage, State}, + validate_version, +}; + +use super::{wrapper::ConnectionState, HandlerUnhandled}; + +pub async fn handshake_handle( + message: &NetworkMessage, + state: &ConnectionState, +) -> Result<(), Error> { + if initiate_handshake + .handle_message(&message, state) + .await + .is_ok() + { + Ok(()) + } else if handshake_response + .handle_message(&message, state) + .await + .is_ok() + { + Ok(()) + } else if handshake_finalize + .handle_message(&message, state) + .await + .is_ok() + { + Ok(()) + } else { + Err(Error::from(HandlerUnhandled)) + } +} + +async fn initiate_handshake( + Message(initiation): Message, + State(connection_state): State, + _instance: AuthenticatedInstance, +) -> Result<(), HandshakeError> { + if !validate_version(&initiation.version) { + error!( + "Version compatibility failure! Our Version: {}, Their Version: {}", + Version::from_str(&std::env::var("CARGO_PKG_VERSION").unwrap()).unwrap(), + initiation.version + ); + + connection_state + .send(HandshakeFinalize { success: false }) + .await + .map_err(|e| HandshakeError::SendError(e))?; + + Ok(()) + } else { + connection_state + .send(HandshakeFinalize { success: true }) + .await + .map_err(|e| HandshakeError::SendError(e))?; + + Ok(()) + } +} + +async fn handshake_response( + Message(response): Message, + State(connection_state): State, + _instance: AuthenticatedInstance, +) -> Result<(), HandshakeError> { + if !validate_version(&response.version) { + error!( + "Version compatibility failure! Our Version: {}, Their Version: {}", + Version::from_str(&std::env::var("CARGO_PKG_VERSION").unwrap()).unwrap(), + response.version + ); + + connection_state + .send(HandshakeFinalize { success: false }) + .await + .map_err(|e| HandshakeError::SendError(e))?; + + Ok(()) + } else { + connection_state + .send(HandshakeFinalize { success: true }) + .await + .map_err(|e| HandshakeError::SendError(e))?; + + Ok(()) + } +} + +async fn handshake_finalize( + Message(finalize): Message, + State(connection_state): State, + _instance: AuthenticatedInstance, +) -> Result<(), HandshakeError> { + if !finalize.success { + error!("Error during handshake, aborting connection"); + return Err(Error::from(ConnectionError::Shutdown).into()); + } else { + connection_state.handshaked.store(true, Ordering::SeqCst); + + connection_state + .send(HandshakeFinalize { success: true }) + .await + .map_err(|e| HandshakeError::SendError(e))?; + + Ok(()) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum HandshakeError { + #[error("version mismatch during handshake, ours: {0}, theirs: {1}")] + VersionMismatch(Version, Version), + #[error("while sending message: {0}")] + SendError(Error), + #[error("{0}")] + Other(#[from] Error), +} diff --git a/src/connection/repository.rs b/src/connection/repository.rs new file mode 100644 index 0000000..96a5eb6 --- /dev/null +++ b/src/connection/repository.rs @@ -0,0 +1,130 @@ +use anyhow::Error; + +use crate::{ + messages::repository::{ + CreateRepositoryRequest, RepositoryFileInspectRequest, RepositoryInfoRequest, + RepositoryIssueLabelsRequest, RepositoryIssuesCountRequest, RepositoryIssuesRequest, + RepositoryRequest, + }, + model::authenticated::{AuthenticatedUser, Message, MessageHandler, NetworkMessage, State}, +}; + +use super::{wrapper::ConnectionState, HandlerUnhandled}; + +pub async fn repository_handle( + message: &NetworkMessage, + state: &ConnectionState, +) -> Result<(), Error> { + if create_repository + .handle_message(&message, state) + .await + .is_ok() + { + Ok(()) + } else if repository_file_inspect + .handle_message(&message, state) + .await + .is_ok() + { + Ok(()) + } else if repository_info + .handle_message(&message, state) + .await + .is_ok() + { + Ok(()) + } else if issues_count.handle_message(&message, state).await.is_ok() { + Ok(()) + } else if issue_labels.handle_message(&message, state).await.is_ok() { + Ok(()) + } else if issues.handle_message(&message, state).await.is_ok() { + Ok(()) + } else { + Err(Error::from(HandlerUnhandled)) + } +} + +async fn create_repository( + Message(request): Message, + State(connection_state): State, + AuthenticatedUser(user): AuthenticatedUser, +) -> Result<(), RepositoryError> { + let mut repository_backend = connection_state.repository_backend.lock().await; + let response = repository_backend + .create_repository(&user, &request) + .await?; + + drop(repository_backend); + + connection_state.send(response).await?; + + Ok(()) +} + +async fn repository_file_inspect( + Message(request): Message, + State(connection_state): State, + user: Option, +) -> Result<(), RepositoryError> { + let user = user.map(|u| u.0); + + let mut repository_backend = connection_state.repository_backend.lock().await; + let response = repository_backend + .repository_file_inspect(user.as_ref(), &request) + .await?; + + drop(repository_backend); + + connection_state.send(response).await?; + + Ok(()) +} + +async fn repository_info( + Message(request): Message, + State(connection_state): State, + user: Option, +) -> Result<(), RepositoryError> { + let user = user.map(|u| u.0); + + let mut repository_backend = connection_state.repository_backend.lock().await; + let response = repository_backend + .repository_info(user.as_ref(), &request) + .await?; + + drop(repository_backend); + + connection_state.send(response).await?; + + Ok(()) +} + +async fn issues_count( + Message(request): Message, + State(connection_state): State, + user: Option, +) -> Result<(), RepositoryError> { + unimplemented!(); +} + +async fn issue_labels( + Message(request): Message, + State(connection_state): State, + user: Option, +) -> Result<(), RepositoryError> { + unimplemented!(); +} + +async fn issues( + Message(request): Message, + State(connection_state): State, + user: Option, +) -> Result<(), RepositoryError> { + unimplemented!(); +} + +#[derive(Debug, thiserror::Error)] +pub enum RepositoryError { + #[error("{0}")] + Other(#[from] Error), +} diff --git a/src/connection/user.rs b/src/connection/user.rs new file mode 100644 index 0000000..a6b0067 --- /dev/null +++ b/src/connection/user.rs @@ -0,0 +1,104 @@ +use anyhow::Error; + +use crate::{ + messages::user::{ + UserBioRequest, UserDisplayImageRequest, UserDisplayNameRequest, UserRepositoriesRequest, + UserRepositoriesResponse, + }, + model::authenticated::{Message, MessageHandler, NetworkMessage, State}, +}; + +use super::{wrapper::ConnectionState, HandlerUnhandled}; + +pub async fn user_handle(message: &NetworkMessage, state: &ConnectionState) -> Result<(), Error> { + if display_name.handle_message(&message, state).await.is_ok() { + Ok(()) + } else if display_image.handle_message(&message, state).await.is_ok() { + Ok(()) + } else if bio.handle_message(&message, state).await.is_ok() { + Ok(()) + } else { + Err(Error::from(HandlerUnhandled)) + } +} + +async fn display_name( + Message(request): Message, + State(connection_state): State, +) -> Result<(), UserError> { + let mut user_backend = connection_state.user_backend.lock().await; + let response = user_backend.display_name(request.clone()).await?; + + drop(user_backend); + + connection_state.send(response).await?; + + Ok(()) +} + +async fn display_image( + Message(request): Message, + State(connection_state): State, +) -> Result<(), UserError> { + let mut user_backend = connection_state.user_backend.lock().await; + let response = user_backend.display_image(request.clone()).await?; + + drop(user_backend); + + connection_state.send(response).await?; + + Ok(()) +} + +async fn bio( + Message(request): Message, + State(connection_state): State, +) -> Result<(), UserError> { + let mut user_backend = connection_state.user_backend.lock().await; + let response = user_backend.bio(request.clone()).await?; + + drop(user_backend); + + connection_state.send(response).await?; + + Ok(()) +} + +async fn repositories( + Message(request): Message, + State(connection_state): State, +) -> Result<(), UserError> { + let mut repository_backend = connection_state.repository_backend.lock().await; + + let repositories = repository_backend + .repositories_for_user(&request.user) + .await; + + let repositories = match repositories { + Ok(repositories) => repositories, + Err(err) => { + error!("Error handling request: {:?}", err); + return Ok(()); + } + }; + drop(repository_backend); + + let mut user_backend = connection_state.user_backend.lock().await; + let user_exists = user_backend.exists(&request.user).await; + + if repositories.is_empty() && !matches!(user_exists, Ok(true)) { + panic!() + } + + let response: UserRepositoriesResponse = UserRepositoriesResponse { repositories }; + + connection_state.send(response).await?; + + Ok(()) +} + +#[derive(Debug, thiserror::Error)] +pub enum UserError { + #[error("{0}")] + Other(#[from] Error), +} diff --git a/src/connection/wrapper.rs b/src/connection/wrapper.rs new file mode 100644 index 0000000..0e44212 --- /dev/null +++ b/src/connection/wrapper.rs @@ -0,0 +1,81 @@ +use std::{ + net::SocketAddr, + sync::{atomic::AtomicBool, Arc}, +}; + +use anyhow::Error; +use futures_util::SinkExt; +use serde::Serialize; +use tokio::{net::TcpStream, sync::Mutex}; +use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; + +use crate::{ + authentication::AuthenticationTokenGranter, + backend::{DiscoveryBackend, RepositoryBackend, UserBackend}, + connection::ConnectionError, + listener::Listeners, + model::instance::Instance, +}; + +use super::{connection_worker, Connections}; + +pub async fn connection_wrapper( + mut socket: WebSocketStream, + listeners: Arc>, + connections: Arc>, + repository_backend: Arc>, + user_backend: Arc>, + auth_granter: Arc>, + discovery_backend: Arc>, + addr: SocketAddr, +) { + let mut handshaked = false; + loop { + if let Err(e) = connection_worker( + &mut socket, + &mut handshaked, + &listeners, + &connections, + &repository_backend, + &user_backend, + &auth_granter, + &discovery_backend, + &addr, + ) + .await + { + error!("Error handling message: {:?}", e); + + if let ConnectionError::Shutdown = &e { + info!("Closing connection {}", addr); + return; + } + } + } +} + +#[derive(Clone)] +pub struct ConnectionState { + socket: Arc>>, + pub listeners: Arc>, + pub connections: Arc>, + pub repository_backend: Arc>, + pub user_backend: Arc>, + pub auth_granter: Arc>, + pub discovery_backend: Arc>, + pub addr: SocketAddr, + pub instance: Instance, + pub handshaked: Arc, +} + +impl ConnectionState { + pub async fn send(&self, message: T) -> Result<(), Error> { + self.socket + .lock() + .await + .send(Message::Binary(serde_json::to_vec(&message)?)) + .await?; + + Ok(()) + } +} diff --git a/src/handshake.rs b/src/handshake.rs index 0822f2c..15637ee 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -6,7 +6,6 @@ use crate::model::instance::Instance; /// Sent by the initiator of a new inter-daemon connection. #[derive(Clone, Serialize, Deserialize)] pub struct InitiateHandshake { - pub identity: Instance, pub version: Version, } diff --git a/src/lib.rs b/src/lib.rs index de968ca..74a8a25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,8 @@ pub mod listener; pub mod messages; pub mod model; +pub(crate) use std::error::Error as StdError; + #[macro_use] extern crate tracing; diff --git a/src/main.rs b/src/main.rs index 0c85f84..11a4e00 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,8 @@ use giterated_daemon::{ discovery::GiteratedDiscoveryProtocol, git::GitBackend, user::UserAuth, DiscoveryBackend, RepositoryBackend, UserBackend, }, - connection, listener, + connection::{self, wrapper::connection_wrapper}, + listener, model::instance::Instance, }; use listener::Listeners; @@ -108,7 +109,7 @@ async fn main() -> Result<(), Error> { info!("Websocket connection established with {}", address); let connection = RawConnection { - task: tokio::spawn(connection_worker( + task: tokio::spawn(connection_wrapper( connection, listeners.clone(), connections.clone(), diff --git a/src/messages/authentication.rs b/src/messages/authentication.rs index a1b4c58..0c560ae 100644 --- a/src/messages/authentication.rs +++ b/src/messages/authentication.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use crate::model::authenticated::UserAuthenticationToken; + use super::InstanceAuthenticated; /// An authentication message. @@ -19,7 +21,7 @@ pub enum AuthenticationRequest { /// # Authentication /// - Instance Authentication /// - **ONLY ACCEPTED WHEN SAME-INSTANCE** - RegisterAccount(InstanceAuthenticated), + RegisterAccount(RegisterAccountRequest), /// An authentication token request. /// @@ -27,24 +29,24 @@ pub enum AuthenticationRequest { /// /// # Authentication /// - Instance Authentication - /// - **ONLY ACCEPTED WHEN SAME-INSTANCE** /// - Identifies the Instance to issue the token for /// # Authorization /// - Credentials ([`crate::backend::AuthBackend`]-based) /// - Identifies the User account to issue a token for /// - Decrypts user private key to issue to - AuthenticationToken(InstanceAuthenticated), + AuthenticationToken(AuthenticationTokenRequest), /// An authentication token extension request. /// /// # Authentication /// - Instance Authentication - /// - **ONLY ACCEPTED WHEN SAME-INSTANCE** /// - Identifies the Instance to issue the token for + /// - User Authentication + /// - Authenticates the validity of the token /// # Authorization /// - Token-based /// - Validates authorization using token's authenticity - TokenExtension(InstanceAuthenticated), + TokenExtension(TokenExtensionRequest), } #[derive(Clone, Serialize, Deserialize)] @@ -70,7 +72,6 @@ pub struct RegisterAccountResponse { /// See [`AuthenticationRequest::AuthenticationToken`]'s documentation. #[derive(Clone, Serialize, Deserialize)] pub struct AuthenticationTokenRequest { - pub secret_key: String, pub username: String, pub password: String, } @@ -83,8 +84,7 @@ pub struct AuthenticationTokenResponse { /// See [`AuthenticationRequest::TokenExtension`]'s documentation. #[derive(Clone, Serialize, Deserialize)] pub struct TokenExtensionRequest { - pub secret_key: String, - pub token: String, + pub token: UserAuthenticationToken, } #[derive(Clone, Serialize, Deserialize)] diff --git a/src/messages/mod.rs b/src/messages/mod.rs index 5753f79..1af76c2 100644 --- a/src/messages/mod.rs +++ b/src/messages/mod.rs @@ -34,6 +34,15 @@ pub enum MessageKind { Authentication(AuthenticationMessage), Discovery(DiscoveryMessage), User(UserMessage), + Error(ErrorMessage), +} + +#[derive(Clone, Debug, Serialize, Deserialize, thiserror::Error)] +pub enum ErrorMessage { + #[error("user {0} doesn't exist or isn't valid in this context")] + InvalidUser(User), + #[error("internal error: shutdown")] + Shutdown, } /// An authenticated message, where the instance is authenticating diff --git a/src/model/authenticated.rs b/src/model/authenticated.rs new file mode 100644 index 0000000..0d10197 --- /dev/null +++ b/src/model/authenticated.rs @@ -0,0 +1,335 @@ +use std::{any::type_name, ops::Deref, pin::Pin, str::FromStr}; + +use anyhow::Error; +use futures_util::{future::BoxFuture, Future, FutureExt}; +use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation}; +use rsa::{pkcs1::DecodeRsaPublicKey, RsaPublicKey}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +use crate::{ + authentication::UserTokenMetadata, connection::wrapper::ConnectionState, messages::MessageKind, +}; + +use super::{instance::Instance, user::User}; + +#[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct Authenticated { + #[serde(flatten)] + source: Vec, + message_type: String, + #[serde(flatten)] + message: T, +} + +pub trait AuthenticationSourceProvider: Sized { + fn authenticate(self, payload: &Vec) -> AuthenticationSource; +} + +pub trait AuthenticationSourceProviders: Sized { + fn authenticate_all(self, payload: &Vec) -> Vec; +} + +impl AuthenticationSourceProviders for A +where + A: AuthenticationSourceProvider, +{ + fn authenticate_all(self, payload: &Vec) -> Vec { + vec![self.authenticate(payload)] + } +} + +impl AuthenticationSourceProviders for (A, B) +where + A: AuthenticationSourceProvider, + B: AuthenticationSourceProvider, +{ + fn authenticate_all(self, payload: &Vec) -> Vec { + let (first, second) = self; + + vec![first.authenticate(payload), second.authenticate(payload)] + } +} + +impl Authenticated { + pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self { + let message_payload = serde_json::to_vec(&message).unwrap(); + + let authentication = auth_sources.authenticate_all(&message_payload); + + Self { + source: authentication, + message_type: type_name::().to_string(), + message, + } + } +} + +mod verified {} + +pub struct UserAuthenticator { + pub user: User, + pub token: UserAuthenticationToken, +} + +impl AuthenticationSourceProvider for UserAuthenticator { + fn authenticate(self, payload: &Vec) -> AuthenticationSource { + AuthenticationSource::User { + user: self.user, + token: self.token, + } + } +} + +pub struct InstanceAuthenticator<'a> { + pub instance: Instance, + pub private_key: &'a str, +} + +impl AuthenticationSourceProvider for InstanceAuthenticator<'_> { + fn authenticate(self, payload: &Vec) -> AuthenticationSource { + todo!() + } +} + +#[repr(transparent)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct UserAuthenticationToken(String); + +impl From for UserAuthenticationToken { + fn from(value: String) -> Self { + Self(value) + } +} + +impl ToString for UserAuthenticationToken { + fn to_string(&self) -> String { + self.0.clone() + } +} + +impl AsRef for UserAuthenticationToken { + fn as_ref(&self) -> &str { + &self.0 + } +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct InstanceSignature(Vec); + +#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub enum AuthenticationSource { + User { + user: User, + token: UserAuthenticationToken, + }, + Instance { + instance: Instance, + signature: InstanceSignature, + }, +} + +pub struct NetworkMessage(pub Vec); + +impl Deref for NetworkMessage { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub struct AuthenticatedUser(pub User); + +#[derive(Debug, thiserror::Error)] +pub enum UserAuthenticationError { + #[error("user authentication missing")] + Missing, + // #[error("{0}")] + // InstanceAuthentication(#[from] Error), + #[error("user token was invalid")] + InvalidToken, + #[error("an error has occured")] + Other(#[from] Error), +} + +pub struct AuthenticatedInstance(Instance); + +impl AuthenticatedInstance { + pub fn inner(&self) -> &Instance { + &self.0 + } +} + +#[async_trait::async_trait] +pub trait FromMessage: Sized + Send + Sync { + async fn from_message(message: &NetworkMessage, state: &S) -> Result; +} + +#[async_trait::async_trait] +impl FromMessage for AuthenticatedUser { + async fn from_message( + network_message: &NetworkMessage, + state: &ConnectionState, + ) -> Result { + let message: Authenticated = + serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?; + + let (auth_user, auth_token) = message + .source + .iter() + .filter_map(|auth| { + if let AuthenticationSource::User { user, token } = auth { + Some((user, token)) + } else { + None + } + }) + .next() + .ok_or_else(|| UserAuthenticationError::Missing)?; + + let authenticated_instance = + AuthenticatedInstance::from_message(network_message, state).await?; + + let public_key_raw = public_key(&auth_user.instance).await?; + let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap(); + + let data: TokenData = decode( + auth_token.as_ref(), + &verification_key, + &Validation::new(Algorithm::RS256), + ) + .unwrap(); + + if data.claims.user != *auth_user + || data.claims.generated_for != *authenticated_instance.inner() + { + Err(Error::from(UserAuthenticationError::InvalidToken)) + } else { + Ok(AuthenticatedUser(data.claims.user)) + } + } +} + +#[async_trait::async_trait] +impl FromMessage for AuthenticatedInstance { + async fn from_message( + message: &NetworkMessage, + state: &ConnectionState, + ) -> Result { + todo!() + } +} + +#[async_trait::async_trait] +impl FromMessage for MessageKind { + async fn from_message( + message: &NetworkMessage, + state: &ConnectionState, + ) -> Result { + todo!() + } +} + +#[async_trait::async_trait] +impl FromMessage for Option +where + T: FromMessage, + S: Send + Sync + 'static, +{ + async fn from_message(message: &NetworkMessage, state: &S) -> Result { + Ok(T::from_message(message, state).await.ok()) + } +} + +#[async_trait::async_trait] +pub trait MessageHandler { + async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result; +} +#[async_trait::async_trait] +impl MessageHandler<(T1,), S, R> for T +where + T: FnOnce(T1) -> F + Clone + Send + 'static, + F: Future> + Send, + T1: FromMessage + Send, + S: Send + Sync, + E: std::error::Error + Send + Sync + 'static, +{ + async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result { + let value = T1::from_message(message, state).await?; + self(value).await.map_err(|e| Error::from(e)) + } +} + +#[async_trait::async_trait] +impl MessageHandler<(T1, T2), S, R> for T +where + T: FnOnce(T1, T2) -> F + Clone + Send + 'static, + F: Future> + Send, + T1: FromMessage + Send, + T2: FromMessage + Send, + S: Send + Sync, + E: std::error::Error + Send + Sync + 'static, +{ + async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result { + let value = T1::from_message(message, state).await?; + let value_2 = T2::from_message(message, state).await?; + self(value, value_2).await.map_err(|e| Error::from(e)) + } +} + +#[async_trait::async_trait] +impl MessageHandler<(T1, T2, T3), S, R> for T +where + T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static, + F: Future> + Send, + T1: FromMessage + Send, + T2: FromMessage + Send, + T3: FromMessage + Send, + S: Send + Sync, + E: std::error::Error + Send + Sync + 'static, +{ + async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result { + let value = T1::from_message(message, state).await?; + let value_2 = T2::from_message(message, state).await?; + let value_3 = T3::from_message(message, state).await?; + + self(value, value_2, value_3) + .await + .map_err(|e| Error::from(e)) + } +} + +pub struct State(pub T); + +#[async_trait::async_trait] +impl FromMessage for State +where + T: Clone + Send + Sync, +{ + async fn from_message(_: &NetworkMessage, state: &T) -> Result { + Ok(Self(state.clone())) + } +} + +// Temp +#[async_trait::async_trait] +impl FromMessage for Message +where + T: DeserializeOwned + Send + Sync + Serialize, + S: Clone + Send + Sync, +{ + async fn from_message(message: &NetworkMessage, _: &S) -> Result { + Ok(Message(serde_json::from_slice(&message)?)) + } +} + +pub struct Message(pub T); + +async fn public_key(instance: &Instance) -> Result { + let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url)) + .await? + .text() + .await?; + + Ok(key) +} diff --git a/src/model/mod.rs b/src/model/mod.rs index a4d3269..80fcc43 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -3,6 +3,7 @@ //! All network data model types that are not directly associated with //! individual requests or responses. +pub mod authenticated; pub mod discovery; pub mod instance; pub mod repository;