diff --git a/src/backend/user.rs b/src/backend/user.rs index 3fc2bca..471712b 100644 --- a/src/backend/user.rs +++ b/src/backend/user.rs @@ -95,8 +95,7 @@ impl UserBackend for UserAuth { request.user.username ) .fetch_one(&self.pg_pool.clone()) - .await - .unwrap(); + .await?; Ok(UserBioResponse { bio: db_row.bio }) } diff --git a/src/connection/authentication.rs b/src/connection/authentication.rs index e194703..6675291 100644 --- a/src/connection/authentication.rs +++ b/src/connection/authentication.rs @@ -8,30 +8,36 @@ use crate::model::authenticated::{AuthenticatedInstance, NetworkMessage, State}; use crate::model::authenticated::{Message, MessageHandler}; use super::wrapper::ConnectionState; -use super::HandlerUnhandled; + pub async fn authentication_handle( message_type: &str, message: &NetworkMessage, state: &ConnectionState, -) -> Result<(), Error> { +) -> Result { match message_type { "&giterated_daemon::messages::authentication::RegisterAccountRequest" => { register_account_request .handle_message(&message, state) - .await + .await?; + + Ok(true) } "&giterated_daemon::messages::authentication::AuthenticationTokenRequest" => { authentication_token_request .handle_message(&message, state) - .await + .await?; + + Ok(true) } "&giterated_daemon::messages::authentication::TokenExtensionRequest" => { token_extension_request .handle_message(&message, state) - .await + .await?; + + Ok(true) } - _ => Err(Error::from(HandlerUnhandled)), + _ => Ok(false), } } diff --git a/src/connection/handshake.rs b/src/connection/handshake.rs index 1368ded..c0bb8ab 100644 --- a/src/connection/handshake.rs +++ b/src/connection/handshake.rs @@ -6,7 +6,7 @@ use semver::Version; use crate::{ connection::ConnectionError, messages::handshake::{HandshakeFinalize, HandshakeResponse, InitiateHandshake}, - model::authenticated::{AuthenticatedInstance, Message, MessageHandler, NetworkMessage, State}, + model::authenticated::{Message, MessageHandler, NetworkMessage, State}, validate_version, version, }; diff --git a/src/connection/repository.rs b/src/connection/repository.rs index 117d3ea..c9a6f2c 100644 --- a/src/connection/repository.rs +++ b/src/connection/repository.rs @@ -9,35 +9,47 @@ use crate::{ model::authenticated::{AuthenticatedUser, Message, MessageHandler, NetworkMessage, State}, }; -use super::{wrapper::ConnectionState, HandlerUnhandled}; +use super::{wrapper::ConnectionState}; pub async fn repository_handle( message_type: &str, message: &NetworkMessage, state: &ConnectionState, -) -> Result<(), Error> { +) -> Result { match message_type { "&giterated_daemon::messages::repository::RepositoryCreateRequest" => { - create_repository.handle_message(&message, state).await + create_repository.handle_message(&message, state).await?; + + Ok(true) } "&giterated_daemon::messages::repository::RepositoryFileInspectRequest" => { repository_file_inspect .handle_message(&message, state) - .await + .await?; + + Ok(true) } "&giterated_daemon::messages::repository::RepositoryInfoRequest" => { - repository_info.handle_message(&message, state).await + repository_info.handle_message(&message, state).await?; + + Ok(true) } "&giterated_daemon::messages::repository::RepositoryIssuesCountRequest" => { - issues_count.handle_message(&message, state).await + issues_count.handle_message(&message, state).await?; + + Ok(true) } "&giterated_daemon::messages::repository::RepositoryIssueLabelsRequest" => { - issue_labels.handle_message(&message, state).await + issue_labels.handle_message(&message, state).await?; + + Ok(true) } "&giterated_daemon::messages::repository::RepositoryIssuesRequest" => { - issues.handle_message(&message, state).await + issues.handle_message(&message, state).await?; + + Ok(true) } - _ => Err(Error::from(HandlerUnhandled)), + _ => Ok(false), } } diff --git a/src/connection/user.rs b/src/connection/user.rs index fee26bd..7133bc8 100644 --- a/src/connection/user.rs +++ b/src/connection/user.rs @@ -1,5 +1,5 @@ use anyhow::Error; -use serde_json::Value; + use crate::model::authenticated::AuthenticatedUser; use crate::model::user::User; @@ -11,27 +11,35 @@ use crate::{ model::authenticated::{Message, MessageHandler, NetworkMessage, State}, }; -use super::{wrapper::ConnectionState, HandlerUnhandled}; +use super::{wrapper::ConnectionState}; pub async fn user_handle( message_type: &str, message: &NetworkMessage, state: &ConnectionState, -) -> Result<(), Error> { +) -> Result { match message_type { "&giterated_daemon::messages::user::UserDisplayNameRequest" => { - display_name.handle_message(&message, state).await + display_name.handle_message(&message, state).await?; + + Ok(true) } "&giterated_daemon::messages::user::UserDisplayImageRequest" => { - display_image.handle_message(&message, state).await + display_image.handle_message(&message, state).await?; + + Ok(true) } "&giterated_daemon::messages::user::UserBioRequest" => { - bio.handle_message(&message, state).await + bio.handle_message(&message, state).await?; + + Ok(true) } "&giterated_daemon::messages::user::UserRepositoriesRequest" => { - repositories.handle_message(&message, state).await + repositories.handle_message(&message, state).await?; + + Ok(true) } - _ => Err(Error::from(HandlerUnhandled)), + _ => Ok(false), } } diff --git a/src/connection/wrapper.rs b/src/connection/wrapper.rs index a6e5237..f544091 100644 --- a/src/connection/wrapper.rs +++ b/src/connection/wrapper.rs @@ -16,6 +16,7 @@ use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; use crate::{ authentication::AuthenticationTokenGranter, backend::{RepositoryBackend, UserBackend}, + messages::error::ConnectionError, model::{authenticated::NetworkMessage, instance::Instance}, }; @@ -77,28 +78,42 @@ pub async fn connection_wrapper( let raw = serde_json::from_slice::(&payload).unwrap(); let message_type = raw.get("message_type").unwrap().as_str().unwrap(); - if authentication_handle(message_type, &message, &connection_state) - .await - .is_ok() - { - continue; - } else if repository_handle(message_type, &message, &connection_state) - .await - .is_ok() - { - continue; - } else if user_handle(message_type, &message, &connection_state) - .await - .is_ok() - { - continue; - } else { - error!( - "Message completely unhandled: {}", - std::str::from_utf8(&payload).unwrap() - ); - continue; + match authentication_handle(message_type, &message, &connection_state).await { + Err(e) => { + let _ = connection_state.send(ConnectionError(e.to_string())).await; + } + Ok(true) => continue, + Ok(false) => {} + } + + match repository_handle(message_type, &message, &connection_state).await { + Err(e) => { + let _ = connection_state.send(ConnectionError(e.to_string())).await; + } + Ok(true) => continue, + Ok(false) => {} } + + match user_handle(message_type, &message, &connection_state).await { + Err(e) => { + let _ = connection_state.send(ConnectionError(e.to_string())).await; + } + Ok(true) => continue, + Ok(false) => {} + } + + match authentication_handle(message_type, &message, &connection_state).await { + Err(e) => { + let _ = connection_state.send(ConnectionError(e.to_string())).await; + } + Ok(true) => continue, + Ok(false) => {} + } + + error!( + "Message completely unhandled: {}", + std::str::from_utf8(&payload).unwrap() + ); } } Some(Err(e)) => { diff --git a/src/messages/error.rs b/src/messages/error.rs new file mode 100644 index 0000000..4dcbc82 --- /dev/null +++ b/src/messages/error.rs @@ -0,0 +1,5 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, thiserror::Error)] +#[error("error from connection: {0}")] +pub struct ConnectionError(pub String); diff --git a/src/messages/mod.rs b/src/messages/mod.rs index 6e31b47..bc9f449 100644 --- a/src/messages/mod.rs +++ b/src/messages/mod.rs @@ -5,6 +5,7 @@ use crate::model::user::User; pub mod authentication; pub mod discovery; +pub mod error; pub mod handshake; pub mod issues; pub mod repository; diff --git a/src/model/authenticated.rs b/src/model/authenticated.rs index f5f8b5f..b865cc3 100644 --- a/src/model/authenticated.rs +++ b/src/model/authenticated.rs @@ -239,7 +239,7 @@ impl FromMessage for AuthenticatedInstance { let message: Authenticated> = serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?; - let (instance, signature) = message + let (instance, _signature) = message .source .iter() .filter_map(|auth| {