diff --git a/Cargo.lock b/Cargo.lock index f46eedf..e0180a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -163,6 +163,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] name = "bitflags" version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -678,6 +687,7 @@ dependencies = [ "argon2", "async-trait", "base64 0.21.3", + "bincode", "chrono", "deadpool", "futures-util", @@ -711,6 +721,7 @@ dependencies = [ "argon2", "async-trait", "base64 0.21.3", + "bincode", "chrono", "futures-util", "git2", diff --git a/giterated-daemon/Cargo.toml b/giterated-daemon/Cargo.toml index 464ca0c..561bc39 100644 --- a/giterated-daemon/Cargo.toml +++ b/giterated-daemon/Cargo.toml @@ -26,6 +26,7 @@ tower = "*" giterated-models = { path = "../giterated-models" } giterated-api = { path = "../../giterated-api" } deadpool = "*" +bincode = "*" toml = { version = "0.7" } diff --git a/giterated-daemon/src/connection/forwarded.rs b/giterated-daemon/src/connection/forwarded.rs index 1ace443..01a3d20 100644 --- a/giterated-daemon/src/connection/forwarded.rs +++ b/giterated-daemon/src/connection/forwarded.rs @@ -1,12 +1,12 @@ use futures_util::{SinkExt, StreamExt}; use giterated_api::DaemonConnectionPool; -use giterated_models::{messages::error::ConnectionError, model::authenticated::Authenticated}; +use giterated_models::{messages::error::ConnectionError, model::authenticated::{Authenticated, AuthenticatedPayload}}; use serde::Serialize; use tokio_tungstenite::tungstenite::Message; -pub async fn wrap_forwarded( +pub async fn wrap_forwarded( pool: &DaemonConnectionPool, - message: Authenticated, + message: AuthenticatedPayload, ) -> Message { let connection = pool.get().await; diff --git a/giterated-daemon/src/connection/handshake.rs b/giterated-daemon/src/connection/handshake.rs index 111c613..e053a59 100644 --- a/giterated-daemon/src/connection/handshake.rs +++ b/giterated-daemon/src/connection/handshake.rs @@ -8,7 +8,7 @@ use semver::Version; use crate::{ connection::ConnectionError, - message::{Message, MessageHandler, NetworkMessage, State}, + message::{Message, MessageHandler, NetworkMessage, State, HandshakeMessage}, validate_version, version, }; @@ -42,9 +42,10 @@ pub async fn handshake_handle( } async fn initiate_handshake( - Message(initiation): Message, + HandshakeMessage(initiation): HandshakeMessage, State(connection_state): State, ) -> Result<(), HandshakeError> { + info!("meow!"); connection_state .send(HandshakeResponse { identity: connection_state.instance.clone(), @@ -81,7 +82,7 @@ async fn initiate_handshake( } async fn handshake_response( - Message(response): Message, + HandshakeMessage(initiation): HandshakeMessage, State(connection_state): State, ) -> Result<(), HandshakeError> { connection_state @@ -114,7 +115,7 @@ async fn handshake_response( } async fn handshake_finalize( - Message(finalize): Message, + HandshakeMessage(finalize): HandshakeMessage, State(connection_state): State, ) -> Result<(), HandshakeError> { connection_state.handshaked.store(true, Ordering::SeqCst); diff --git a/giterated-daemon/src/connection/wrapper.rs b/giterated-daemon/src/connection/wrapper.rs index 38b3a48..f8e0060 100644 --- a/giterated-daemon/src/connection/wrapper.rs +++ b/giterated-daemon/src/connection/wrapper.rs @@ -11,7 +11,7 @@ use anyhow::Error; use futures_util::{SinkExt, StreamExt}; use giterated_models::{ messages::error::ConnectionError, - model::{authenticated::Authenticated, instance::Instance}, + model::{authenticated::{Authenticated, AuthenticatedPayload}, instance::Instance}, }; use rsa::RsaPublicKey; use serde::Serialize; @@ -86,13 +86,14 @@ pub async fn connection_wrapper( let message = NetworkMessage(payload.clone()); if !handshaked { + info!("im foo baring"); if handshake_handle(&message, &connection_state).await.is_ok() { if connection_state.handshaked.load(Ordering::SeqCst) { handshaked = true; } } } else { - let raw = serde_json::from_slice::>(&payload).unwrap(); + let raw = serde_json::from_slice::(&payload).unwrap(); if let Some(target_instance) = &raw.target_instance { // Forward request @@ -116,7 +117,9 @@ pub async fn connection_wrapper( match authentication_handle(message_type, &message, &connection_state).await { Err(e) => { - let _ = connection_state.send(ConnectionError(e.to_string())).await; + let _ = connection_state + .send_raw(ConnectionError(e.to_string())) + .await; } Ok(true) => continue, Ok(false) => {} @@ -124,7 +127,9 @@ pub async fn connection_wrapper( match repository_handle(message_type, &message, &connection_state).await { Err(e) => { - let _ = connection_state.send(ConnectionError(e.to_string())).await; + let _ = connection_state + .send_raw(ConnectionError(e.to_string())) + .await; } Ok(true) => continue, Ok(false) => {} @@ -132,7 +137,9 @@ pub async fn connection_wrapper( match user_handle(message_type, &message, &connection_state).await { Err(e) => { - let _ = connection_state.send(ConnectionError(e.to_string())).await; + let _ = connection_state + .send_raw(ConnectionError(e.to_string())) + .await; } Ok(true) => continue, Ok(false) => {} @@ -140,7 +147,9 @@ pub async fn connection_wrapper( match authentication_handle(message_type, &message, &connection_state).await { Err(e) => { - let _ = connection_state.send(ConnectionError(e.to_string())).await; + let _ = connection_state + .send_raw(ConnectionError(e.to_string())) + .await; } Ok(true) => continue, Ok(false) => {} @@ -189,4 +198,16 @@ impl ConnectionState { Ok(()) } + + pub async fn send_raw(&self, message: T) -> Result<(), Error> { + let payload = serde_json::to_string(&message)?; + info!("Sending payload: {}", &payload); + self.socket + .lock() + .await + .send(Message::Binary(payload.into_bytes())) + .await?; + + Ok(()) + } } diff --git a/giterated-daemon/src/message.rs b/giterated-daemon/src/message.rs index 67d8205..987010d 100644 --- a/giterated-daemon/src/message.rs +++ b/giterated-daemon/src/message.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, ops::Deref}; use anyhow::Error; use futures_util::Future; use giterated_models::model::{ - authenticated::{Authenticated, AuthenticationSource, UserTokenMetadata}, + authenticated::{Authenticated, AuthenticationSource, UserTokenMetadata, AuthenticatedPayload}, instance::Instance, user::User, }; @@ -63,7 +63,7 @@ impl FromMessage for AuthenticatedUser { network_message: &NetworkMessage, state: &ConnectionState, ) -> Result { - let message: Authenticated> = + let message: AuthenticatedPayload = serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?; let (auth_user, auth_token) = message @@ -108,9 +108,9 @@ impl FromMessage for AuthenticatedInstance { network_message: &NetworkMessage, state: &ConnectionState, ) -> Result { - let message: Authenticated> = + let message: AuthenticatedPayload = serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?; - + let (instance, signature) = message .source .iter() @@ -137,7 +137,7 @@ impl FromMessage for AuthenticatedInstance { } else { drop(cached_keys); let mut cached_keys = state.cached_keys.write().await; - let key = public_key(instance).await?; + let key = public_key(&instance).await?; let public_key = RsaPublicKey::from_pkcs1_pem(&key).unwrap(); cached_keys.insert(instance.clone(), public_key.clone()); public_key @@ -146,15 +146,8 @@ impl FromMessage for AuthenticatedInstance { let verifying_key: VerifyingKey = VerifyingKey::new(public_key); - let message_json = serde_json::to_vec(&message.message).unwrap(); - - info!( - "Verification against: {}", - std::str::from_utf8(&message_json).unwrap() - ); - verifying_key.verify( - &message_json, + &message.payload, &Signature::try_from(signature.as_ref()).unwrap(), )?; @@ -251,7 +244,8 @@ where S: Clone + Send + Sync, { async fn from_message(message: &NetworkMessage, _: &S) -> Result { - Ok(Message(serde_json::from_slice(&message)?)) + let payload: AuthenticatedPayload = serde_json::from_slice(&message)?; + Ok(Message(bincode::deserialize(&payload.payload)?)) } } @@ -265,3 +259,20 @@ async fn public_key(instance: &Instance) -> Result { Ok(key) } + +/// Handshake-specific message type. +/// +/// Uses basic serde_json-based deserialization to maintain the highest +/// level of compatibility across versions. +pub struct HandshakeMessage(pub T); + +#[async_trait::async_trait] +impl FromMessage for HandshakeMessage +where + T: DeserializeOwned + Send + Sync + Serialize, + S: Clone + Send + Sync, +{ + async fn from_message(message: &NetworkMessage, _: &S) -> Result { + Ok(HandshakeMessage(serde_json::from_slice(&message.0)?)) + } +} \ No newline at end of file diff --git a/giterated-models/Cargo.toml b/giterated-models/Cargo.toml index 4aff76d..11d80f7 100644 --- a/giterated-models/Cargo.toml +++ b/giterated-models/Cargo.toml @@ -23,6 +23,7 @@ argon2 = "*" aes-gcm = "0.10.2" semver = {version = "*", features = ["serde"]} tower = "*" +bincode = "*" toml = { version = "0.7" } diff --git a/giterated-models/src/model/authenticated.rs b/giterated-models/src/model/authenticated.rs index c5f9ae4..a4e7af6 100644 --- a/giterated-models/src/model/authenticated.rs +++ b/giterated-models/src/model/authenticated.rs @@ -1,4 +1,4 @@ -use std::any::type_name; +use std::{any::type_name, fmt::Debug}; use rsa::{ pkcs1::DecodeRsaPrivateKey, @@ -19,28 +19,52 @@ pub struct UserTokenMetadata { pub exp: u64, } +#[derive(Debug)] +pub struct Authenticated<'a, T: Serialize> { + pub target_instance: Option, + pub source: Vec<&'a dyn AuthenticationSourceProvider>, + pub message_type: String, + pub message: T, +} + #[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] -pub struct Authenticated { +pub struct AuthenticatedPayload { pub target_instance: Option, pub source: Vec, pub message_type: String, - #[serde(flatten)] - pub message: T, + pub payload: Vec, +} + +impl<'a, T: Serialize> From> for AuthenticatedPayload { + fn from(mut value: Authenticated<'a, T>) -> Self { + let payload = bincode::serialize(&value.message).unwrap(); + + AuthenticatedPayload { + target_instance: value.target_instance, + source: value + .source + .drain(..) + .map(|provider| provider.authenticate(&payload)) + .collect::>(), + message_type: value.message_type, + payload, + } + } } -pub trait AuthenticationSourceProvider: Sized { - fn authenticate(self, payload: &Vec) -> AuthenticationSource; +pub trait AuthenticationSourceProvider: Debug { + fn authenticate(&self, payload: &Vec) -> AuthenticationSource; } -pub trait AuthenticationSourceProviders: Sized { - fn authenticate_all(self, payload: &Vec) -> Vec; +pub trait AuthenticationSourceProviders: Debug { + fn authenticate_all(&self, payload: &Vec) -> Vec; } impl AuthenticationSourceProviders for A where A: AuthenticationSourceProvider, { - fn authenticate_all(self, payload: &Vec) -> Vec { + fn authenticate_all(&self, payload: &Vec) -> Vec { vec![self.authenticate(payload)] } } @@ -50,43 +74,26 @@ where A: AuthenticationSourceProvider, B: AuthenticationSourceProvider, { - fn authenticate_all(self, payload: &Vec) -> Vec { + fn authenticate_all(&self, payload: &Vec) -> Vec { let (first, second) = self; vec![first.authenticate(payload), second.authenticate(payload)] } } -impl Authenticated { - pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self { - let message_payload = serde_json::to_vec(&message).unwrap(); - - let authentication = auth_sources.authenticate_all(&message_payload); - +impl<'a, T: Serialize + Debug> Authenticated<'a, T> { + pub fn new(message: T) -> Self { Self { - source: authentication, + source: vec![], message_type: type_name::().to_string(), message, target_instance: None, } } - pub fn new_for( - instance: impl ToOwned, - message: T, - auth_sources: impl AuthenticationSourceProvider, - ) -> Self { - let message_payload = serde_json::to_vec(&message).unwrap(); - - info!( - "Verifying payload: {}", - std::str::from_utf8(&message_payload).unwrap() - ); - - let authentication = auth_sources.authenticate_all(&message_payload); - + pub fn new_for(instance: impl ToOwned, message: T) -> Self { Self { - source: authentication, + source: vec![], message_type: type_name::().to_string(), message, target_instance: Some(instance.to_owned()), @@ -102,7 +109,7 @@ impl Authenticated { } } - pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) { + pub fn append_authentication(&mut self, authentication: &'a dyn AuthenticationSourceProvider) { let message_payload = serde_json::to_vec(&self.message).unwrap(); info!( @@ -110,8 +117,11 @@ impl Authenticated { std::str::from_utf8(&message_payload).unwrap() ); - self.source - .push(authentication.authenticate(&message_payload)); + self.source.push(authentication); + } + + pub fn into_payload(self) -> AuthenticatedPayload { + self.into() } } @@ -124,22 +134,22 @@ pub struct UserAuthenticator { } impl AuthenticationSourceProvider for UserAuthenticator { - fn authenticate(self, _payload: &Vec) -> AuthenticationSource { + fn authenticate(&self, _payload: &Vec) -> AuthenticationSource { AuthenticationSource::User { - user: self.user, - token: self.token, + user: self.user.clone(), + token: self.token.clone(), } } } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct InstanceAuthenticator<'a> { pub instance: Instance, pub private_key: &'a str, } impl AuthenticationSourceProvider for InstanceAuthenticator<'_> { - fn authenticate(self, payload: &Vec) -> AuthenticationSource { + fn authenticate(&self, payload: &Vec) -> AuthenticationSource { let mut rng = rand::thread_rng(); let private_key = RsaPrivateKey::from_pkcs1_pem(self.private_key).unwrap(); @@ -147,7 +157,7 @@ impl AuthenticationSourceProvider for InstanceAuthenticator<'_> { let signature = signing_key.sign_with_rng(&mut rng, &payload); AuthenticationSource::Instance { - instance: self.instance, + instance: self.instance.clone(), // TODO: Actually parse signature from private key signature: InstanceSignature(signature.to_bytes().into_vec()), }