diff --git a/Cargo.lock b/Cargo.lock index 895f227..f46eedf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -347,6 +347,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" [[package]] +name = "deadpool" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "421fe0f90f2ab22016f32a9881be5134fdd71c65298917084b0c7477cbc3856e" +dependencies = [ + "async-trait", + "deadpool-runtime", + "num_cpus", + "retain_mut", + "tokio", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaa37046cc0f6c3cc6090fbdbf73ef0b8ef4cfcc37f6befc0020f63e8cf121e1" + +[[package]] name = "der" version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -628,6 +647,29 @@ dependencies = [ ] [[package]] +name = "giterated-api" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "chrono", + "deadpool", + "futures-util", + "giterated-models", + "jsonwebtoken", + "rand", + "reqwest", + "semver", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-tungstenite", + "tracing", + "tracing-subscriber", +] + +[[package]] name = "giterated-daemon" version = "0.0.6" dependencies = [ @@ -637,8 +679,10 @@ dependencies = [ "async-trait", "base64 0.21.3", "chrono", + "deadpool", "futures-util", "git2", + "giterated-api", "giterated-models", "jsonwebtoken", "log", @@ -1521,6 +1565,12 @@ dependencies = [ ] [[package]] +name = "retain_mut" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4389f1d5789befaf6029ebd9f7dac4af7f7e3d61b69d4f30e2ac02b57e7712b0" + +[[package]] name = "ring" version = "0.16.20" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1578,6 +1628,49 @@ dependencies = [ ] [[package]] +name = "rustls" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" +dependencies = [ + "base64 0.21.3", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d93931baf2d282fff8d3a532bbfd7653f734643161b87e3e01e59a04439bf0d" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] name = "ryu" version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1599,6 +1692,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] name = "security-framework" version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2191,6 +2294,16 @@ dependencies = [ ] [[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] name = "tokio-stream" version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2209,7 +2322,10 @@ checksum = "2b2dbec703c26b00d74844519606ef15d09a7d6857860f84ad223dec002ddea2" dependencies = [ "futures-util", "log", + "rustls", + "rustls-native-certs", "tokio", + "tokio-rustls", "tungstenite", ] @@ -2362,6 +2478,7 @@ dependencies = [ "httparse", "log", "rand", + "rustls", "sha1", "thiserror", "url", diff --git a/giterated-daemon/Cargo.toml b/giterated-daemon/Cargo.toml index a2c761b..464ca0c 100644 --- a/giterated-daemon/Cargo.toml +++ b/giterated-daemon/Cargo.toml @@ -24,6 +24,8 @@ aes-gcm = "0.10.2" semver = {version = "*", features = ["serde"]} tower = "*" giterated-models = { path = "../giterated-models" } +giterated-api = { path = "../../giterated-api" } +deadpool = "*" toml = { version = "0.7" } diff --git a/giterated-daemon/src/connection.rs b/giterated-daemon/src/connection.rs index cb1f370..2d4cd51 100644 --- a/giterated-daemon/src/connection.rs +++ b/giterated-daemon/src/connection.rs @@ -1,4 +1,5 @@ pub mod authentication; +pub mod forwarded; pub mod handshake; pub mod repository; pub mod user; diff --git a/giterated-daemon/src/connection/forwarded.rs b/giterated-daemon/src/connection/forwarded.rs new file mode 100644 index 0000000..1ace443 --- /dev/null +++ b/giterated-daemon/src/connection/forwarded.rs @@ -0,0 +1,61 @@ +use futures_util::{SinkExt, StreamExt}; +use giterated_api::DaemonConnectionPool; +use giterated_models::{messages::error::ConnectionError, model::authenticated::Authenticated}; +use serde::Serialize; +use tokio_tungstenite::tungstenite::Message; + +pub async fn wrap_forwarded( + pool: &DaemonConnectionPool, + message: Authenticated, +) -> Message { + let connection = pool.get().await; + + let mut connection = match connection { + Ok(connection) => connection, + Err(e) => { + return Message::Binary(serde_json::to_vec(&ConnectionError(e.to_string())).unwrap()) + } + }; + + let send_result = connection + .send(Message::Binary(serde_json::to_vec(&message).unwrap())) + .await; + + if let Err(e) = send_result { + return Message::Binary(serde_json::to_vec(&ConnectionError(e.to_string())).unwrap()); + } + + loop { + let message = connection.next().await; + + match message { + Some(Ok(message)) => { + match message { + Message::Binary(payload) => return Message::Binary(payload), + Message::Ping(_) => { + let _ = connection.send(Message::Pong(vec![])).await; + continue; + } + Message::Close(_) => { + return Message::Binary( + String::from("The instance you wanted to talk to hung up on me :(") + .into_bytes(), + ) + } + _ => continue, + }; + } + Some(Err(e)) => { + return Message::Binary( + serde_json::to_vec(&ConnectionError(e.to_string())).unwrap(), + ) + } + _ => { + info!("Unhandled"); + continue; + } + } + } + + todo!() +} diff --git a/giterated-daemon/src/connection/wrapper.rs b/giterated-daemon/src/connection/wrapper.rs index 4954ee8..20cc785 100644 --- a/giterated-daemon/src/connection/wrapper.rs +++ b/giterated-daemon/src/connection/wrapper.rs @@ -9,7 +9,10 @@ use std::{ use anyhow::Error; use futures_util::{SinkExt, StreamExt}; -use giterated_models::{messages::error::ConnectionError, model::instance::Instance}; +use giterated_models::{ + messages::error::ConnectionError, + model::{authenticated::Authenticated, instance::Instance}, +}; use rsa::RsaPublicKey; use serde::Serialize; use serde_json::Value; @@ -22,6 +25,8 @@ use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; use crate::{ authentication::AuthenticationTokenGranter, backend::{RepositoryBackend, UserBackend}, + connection::forwarded::wrap_forwarded, + federation::connections::InstanceConnections, message::NetworkMessage, }; @@ -38,6 +43,7 @@ pub async fn connection_wrapper( auth_granter: Arc>, addr: SocketAddr, instance: impl ToOwned, + instance_connections: Arc>, ) { let connection_state = ConnectionState { socket: Arc::new(Mutex::new(socket)), @@ -81,8 +87,24 @@ pub async fn connection_wrapper( } } } else { - let raw = serde_json::from_slice::(&payload).unwrap(); - let message_type = raw.get("message_type").unwrap().as_str().unwrap(); + let raw = serde_json::from_slice::>(&payload).unwrap(); + + if let Some(target_instance) = &raw.target_instance { + // Forward request + let mut instance_connections = instance_connections.lock().await; + let pool = instance_connections.get_or_open(&target_instance).unwrap(); + let pool_clone = pool.clone(); + drop(pool); + + let result = wrap_forwarded(&pool_clone, raw).await; + + let mut socket = connection_state.socket.lock().await; + let _ = socket.send(result).await; + + continue; + } + + let message_type = &raw.message_type; match authentication_handle(message_type, &message, &connection_state).await { Err(e) => { diff --git a/giterated-daemon/src/federation/connections.rs b/giterated-daemon/src/federation/connections.rs new file mode 100644 index 0000000..64dc6e1 --- /dev/null +++ b/giterated-daemon/src/federation/connections.rs @@ -0,0 +1,23 @@ +use std::collections::HashMap; + +use anyhow::Error; +use giterated_api::DaemonConnectionPool; +use giterated_models::model::instance::Instance; + +#[derive(Default)] +pub struct InstanceConnections { + pools: HashMap, +} + +impl InstanceConnections { + pub fn get_or_open(&mut self, instance: &Instance) -> Result { + if let Some(pool) = self.pools.get(instance) { + Ok(pool.clone()) + } else { + let pool = DaemonConnectionPool::connect(instance.clone()).unwrap(); + self.pools.insert(instance.clone(), pool.clone()); + + Ok(pool) + } + } +} diff --git a/giterated-daemon/src/federation/mod.rs b/giterated-daemon/src/federation/mod.rs new file mode 100644 index 0000000..a67f22c --- /dev/null +++ b/giterated-daemon/src/federation/mod.rs @@ -0,0 +1 @@ +pub mod connections; diff --git a/giterated-daemon/src/lib.rs b/giterated-daemon/src/lib.rs index e70ce9f..05494fb 100644 --- a/giterated-daemon/src/lib.rs +++ b/giterated-daemon/src/lib.rs @@ -5,6 +5,7 @@ use semver::{Version, VersionReq}; pub mod authentication; pub mod backend; pub mod connection; +pub mod federation; pub mod message; #[macro_use] diff --git a/giterated-daemon/src/main.rs b/giterated-daemon/src/main.rs index dfcb56f..5a3f62e 100644 --- a/giterated-daemon/src/main.rs +++ b/giterated-daemon/src/main.rs @@ -4,6 +4,7 @@ use giterated_daemon::{ authentication::AuthenticationTokenGranter, backend::{git::GitBackend, user::UserAuth, RepositoryBackend, UserBackend}, connection::{self, wrapper::connection_wrapper}, + federation::connections::InstanceConnections, }; use giterated_models::model::instance::Instance; use sqlx::{postgres::PgConnectOptions, ConnectOptions, PgPool}; @@ -25,6 +26,7 @@ async fn main() -> Result<(), Error> { tracing_subscriber::fmt::init(); let mut listener = TcpListener::bind("0.0.0.0:7270").await?; let connections: Arc> = Arc::default(); + let instance_connections: Arc> = Arc::default(); let config: Table = { let mut file = File::open("Giterated.toml").await?; let mut text = String::new(); @@ -106,6 +108,7 @@ async fn main() -> Result<(), Error> { token_granter.clone(), address, Instance::from_str("giterated.dev").unwrap(), + instance_connections.clone(), )), };