Add message forwarding
parent: tbd commit: e4fa992
Showing 9 changed files with 234 insertions and 3 deletions
Cargo.lock
@@ -347,6 +347,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" | ||
347 | 347 | checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" |
348 | 348 | |
349 | 349 | [[package]] |
350 | name = "deadpool" | |
351 | version = "0.9.5" | |
352 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
353 | checksum = "421fe0f90f2ab22016f32a9881be5134fdd71c65298917084b0c7477cbc3856e" | |
354 | dependencies = [ | |
355 | "async-trait", | |
356 | "deadpool-runtime", | |
357 | "num_cpus", | |
358 | "retain_mut", | |
359 | "tokio", | |
360 | ] | |
361 | ||
362 | [[package]] | |
363 | name = "deadpool-runtime" | |
364 | version = "0.1.2" | |
365 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
366 | checksum = "eaa37046cc0f6c3cc6090fbdbf73ef0b8ef4cfcc37f6befc0020f63e8cf121e1" | |
367 | ||
368 | [[package]] | |
350 | 369 | name = "der" |
351 | 370 | version = "0.7.8" |
352 | 371 | source = "registry+https://github.com/rust-lang/crates.io-index" |
@@ -628,6 +647,29 @@ dependencies = [ | ||
628 | 647 | ] |
629 | 648 | |
630 | 649 | [[package]] |
650 | name = "giterated-api" | |
651 | version = "0.1.0" | |
652 | dependencies = [ | |
653 | "anyhow", | |
654 | "async-trait", | |
655 | "chrono", | |
656 | "deadpool", | |
657 | "futures-util", | |
658 | "giterated-models", | |
659 | "jsonwebtoken", | |
660 | "rand", | |
661 | "reqwest", | |
662 | "semver", | |
663 | "serde", | |
664 | "serde_json", | |
665 | "thiserror", | |
666 | "tokio", | |
667 | "tokio-tungstenite", | |
668 | "tracing", | |
669 | "tracing-subscriber", | |
670 | ] | |
671 | ||
672 | [[package]] | |
631 | 673 | name = "giterated-daemon" |
632 | 674 | version = "0.0.6" |
633 | 675 | dependencies = [ |
@@ -637,8 +679,10 @@ dependencies = [ | ||
637 | 679 | "async-trait", |
638 | 680 | "base64 0.21.3", |
639 | 681 | "chrono", |
682 | "deadpool", | |
640 | 683 | "futures-util", |
641 | 684 | "git2", |
685 | "giterated-api", | |
642 | 686 | "giterated-models", |
643 | 687 | "jsonwebtoken", |
644 | 688 | "log", |
@@ -1521,6 +1565,12 @@ dependencies = [ | ||
1521 | 1565 | ] |
1522 | 1566 | |
1523 | 1567 | [[package]] |
1568 | name = "retain_mut" | |
1569 | version = "0.1.9" | |
1570 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
1571 | checksum = "4389f1d5789befaf6029ebd9f7dac4af7f7e3d61b69d4f30e2ac02b57e7712b0" | |
1572 | ||
1573 | [[package]] | |
1524 | 1574 | name = "ring" |
1525 | 1575 | version = "0.16.20" |
1526 | 1576 | source = "registry+https://github.com/rust-lang/crates.io-index" |
@@ -1578,6 +1628,49 @@ dependencies = [ | ||
1578 | 1628 | ] |
1579 | 1629 | |
1580 | 1630 | [[package]] |
1631 | name = "rustls" | |
1632 | version = "0.21.7" | |
1633 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
1634 | checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" | |
1635 | dependencies = [ | |
1636 | "log", | |
1637 | "ring", | |
1638 | "rustls-webpki", | |
1639 | "sct", | |
1640 | ] | |
1641 | ||
1642 | [[package]] | |
1643 | name = "rustls-native-certs" | |
1644 | version = "0.6.3" | |
1645 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
1646 | checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" | |
1647 | dependencies = [ | |
1648 | "openssl-probe", | |
1649 | "rustls-pemfile", | |
1650 | "schannel", | |
1651 | "security-framework", | |
1652 | ] | |
1653 | ||
1654 | [[package]] | |
1655 | name = "rustls-pemfile" | |
1656 | version = "1.0.3" | |
1657 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
1658 | checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" | |
1659 | dependencies = [ | |
1660 | "base64 0.21.3", | |
1661 | ] | |
1662 | ||
1663 | [[package]] | |
1664 | name = "rustls-webpki" | |
1665 | version = "0.101.4" | |
1666 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
1667 | checksum = "7d93931baf2d282fff8d3a532bbfd7653f734643161b87e3e01e59a04439bf0d" | |
1668 | dependencies = [ | |
1669 | "ring", | |
1670 | "untrusted", | |
1671 | ] | |
1672 | ||
1673 | [[package]] | |
1581 | 1674 | name = "ryu" |
1582 | 1675 | version = "1.0.15" |
1583 | 1676 | source = "registry+https://github.com/rust-lang/crates.io-index" |
@@ -1599,6 +1692,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" | ||
1599 | 1692 | checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" |
1600 | 1693 | |
1601 | 1694 | [[package]] |
1695 | name = "sct" | |
1696 | version = "0.7.0" | |
1697 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
1698 | checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" | |
1699 | dependencies = [ | |
1700 | "ring", | |
1701 | "untrusted", | |
1702 | ] | |
1703 | ||
1704 | [[package]] | |
1602 | 1705 | name = "security-framework" |
1603 | 1706 | version = "2.9.2" |
1604 | 1707 | source = "registry+https://github.com/rust-lang/crates.io-index" |
@@ -2191,6 +2294,16 @@ dependencies = [ | ||
2191 | 2294 | ] |
2192 | 2295 | |
2193 | 2296 | [[package]] |
2297 | name = "tokio-rustls" | |
2298 | version = "0.24.1" | |
2299 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
2300 | checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" | |
2301 | dependencies = [ | |
2302 | "rustls", | |
2303 | "tokio", | |
2304 | ] | |
2305 | ||
2306 | [[package]] | |
2194 | 2307 | name = "tokio-stream" |
2195 | 2308 | version = "0.1.14" |
2196 | 2309 | source = "registry+https://github.com/rust-lang/crates.io-index" |
@@ -2209,7 +2322,10 @@ checksum = "2b2dbec703c26b00d74844519606ef15d09a7d6857860f84ad223dec002ddea2" | ||
2209 | 2322 | dependencies = [ |
2210 | 2323 | "futures-util", |
2211 | 2324 | "log", |
2325 | "rustls", | |
2326 | "rustls-native-certs", | |
2212 | 2327 | "tokio", |
2328 | "tokio-rustls", | |
2213 | 2329 | "tungstenite", |
2214 | 2330 | ] |
2215 | 2331 | |
@@ -2362,6 +2478,7 @@ dependencies = [ | ||
2362 | 2478 | "httparse", |
2363 | 2479 | "log", |
2364 | 2480 | "rand", |
2481 | "rustls", | |
2365 | 2482 | "sha1", |
2366 | 2483 | "thiserror", |
2367 | 2484 | "url", |
giterated-daemon/Cargo.toml
@@ -24,6 +24,8 @@ aes-gcm = "0.10.2" | ||
24 | 24 | semver = {version = "*", features = ["serde"]} |
25 | 25 | tower = "*" |
26 | 26 | giterated-models = { path = "../giterated-models" } |
27 | giterated-api = { path = "../../giterated-api" } | |
28 | deadpool = "*" | |
27 | 29 | |
28 | 30 | toml = { version = "0.7" } |
29 | 31 |
giterated-daemon/src/connection.rs
@@ -1,4 +1,5 @@ | ||
1 | 1 | pub mod authentication; |
2 | pub mod forwarded; | |
2 | 3 | pub mod handshake; |
3 | 4 | pub mod repository; |
4 | 5 | pub mod user; |
giterated-daemon/src/connection/forwarded.rs
@@ -0,0 +1,61 @@ | ||
1 | use futures_util::{SinkExt, StreamExt}; | |
2 | use giterated_api::DaemonConnectionPool; | |
3 | use giterated_models::{messages::error::ConnectionError, model::authenticated::Authenticated}; | |
4 | use serde::Serialize; | |
5 | use tokio_tungstenite::tungstenite::Message; | |
6 | ||
7 | pub async fn wrap_forwarded<T: Serialize>( | |
8 | pool: &DaemonConnectionPool, | |
9 | message: Authenticated<T>, | |
10 | ) -> Message { | |
11 | let connection = pool.get().await; | |
12 | ||
13 | let mut connection = match connection { | |
14 | Ok(connection) => connection, | |
15 | Err(e) => { | |
16 | return Message::Binary(serde_json::to_vec(&ConnectionError(e.to_string())).unwrap()) | |
17 | } | |
18 | }; | |
19 | ||
20 | let send_result = connection | |
21 | .send(Message::Binary(serde_json::to_vec(&message).unwrap())) | |
22 | .await; | |
23 | ||
24 | if let Err(e) = send_result { | |
25 | return Message::Binary(serde_json::to_vec(&ConnectionError(e.to_string())).unwrap()); | |
26 | } | |
27 | ||
28 | loop { | |
29 | let message = connection.next().await; | |
30 | ||
31 | match message { | |
32 | Some(Ok(message)) => { | |
33 | match message { | |
34 | Message::Binary(payload) => return Message::Binary(payload), | |
35 | Message::Ping(_) => { | |
36 | let _ = connection.send(Message::Pong(vec![])).await; | |
37 | continue; | |
38 | } | |
39 | Message::Close(_) => { | |
40 | return Message::Binary( | |
41 | String::from("The instance you wanted to talk to hung up on me :(") | |
42 | .into_bytes(), | |
43 | ) | |
44 | } | |
45 | _ => continue, | |
46 | }; | |
47 | } | |
48 | Some(Err(e)) => { | |
49 | return Message::Binary( | |
50 | serde_json::to_vec(&ConnectionError(e.to_string())).unwrap(), | |
51 | ) | |
52 | } | |
53 | _ => { | |
54 | info!("Unhandled"); | |
55 | continue; | |
56 | } | |
57 | } | |
58 | } | |
59 | ||
60 | todo!() | |
61 | } |
giterated-daemon/src/connection/wrapper.rs
@@ -9,7 +9,10 @@ use std::{ | ||
9 | 9 | |
10 | 10 | use anyhow::Error; |
11 | 11 | use futures_util::{SinkExt, StreamExt}; |
12 | use giterated_models::{messages::error::ConnectionError, model::instance::Instance}; | |
12 | use giterated_models::{ | |
13 | messages::error::ConnectionError, | |
14 | model::{authenticated::Authenticated, instance::Instance}, | |
15 | }; | |
13 | 16 | use rsa::RsaPublicKey; |
14 | 17 | use serde::Serialize; |
15 | 18 | use serde_json::Value; |
@@ -22,6 +25,8 @@ use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; | ||
22 | 25 | use crate::{ |
23 | 26 | authentication::AuthenticationTokenGranter, |
24 | 27 | backend::{RepositoryBackend, UserBackend}, |
28 | connection::forwarded::wrap_forwarded, | |
29 | federation::connections::InstanceConnections, | |
25 | 30 | message::NetworkMessage, |
26 | 31 | }; |
27 | 32 | |
@@ -38,6 +43,7 @@ pub async fn connection_wrapper( | ||
38 | 43 | auth_granter: Arc<Mutex<AuthenticationTokenGranter>>, |
39 | 44 | addr: SocketAddr, |
40 | 45 | instance: impl ToOwned<Owned = Instance>, |
46 | instance_connections: Arc<Mutex<InstanceConnections>>, | |
41 | 47 | ) { |
42 | 48 | let connection_state = ConnectionState { |
43 | 49 | socket: Arc::new(Mutex::new(socket)), |
@@ -81,8 +87,24 @@ pub async fn connection_wrapper( | ||
81 | 87 | } |
82 | 88 | } |
83 | 89 | } else { |
84 | let raw = serde_json::from_slice::<Value>(&payload).unwrap(); | |
85 | let message_type = raw.get("message_type").unwrap().as_str().unwrap(); | |
90 | let raw = serde_json::from_slice::<Authenticated<Value>>(&payload).unwrap(); | |
91 | ||
92 | if let Some(target_instance) = &raw.target_instance { | |
93 | // Forward request | |
94 | let mut instance_connections = instance_connections.lock().await; | |
95 | let pool = instance_connections.get_or_open(&target_instance).unwrap(); | |
96 | let pool_clone = pool.clone(); | |
97 | drop(pool); | |
98 | ||
99 | let result = wrap_forwarded(&pool_clone, raw).await; | |
100 | ||
101 | let mut socket = connection_state.socket.lock().await; | |
102 | let _ = socket.send(result).await; | |
103 | ||
104 | continue; | |
105 | } | |
106 | ||
107 | let message_type = &raw.message_type; | |
86 | 108 | |
87 | 109 | match authentication_handle(message_type, &message, &connection_state).await { |
88 | 110 | Err(e) => { |
giterated-daemon/src/federation/connections.rs
@@ -0,0 +1,23 @@ | ||
1 | use std::collections::HashMap; | |
2 | ||
3 | use anyhow::Error; | |
4 | use giterated_api::DaemonConnectionPool; | |
5 | use giterated_models::model::instance::Instance; | |
6 | ||
7 | #[derive(Default)] | |
8 | pub struct InstanceConnections { | |
9 | pools: HashMap<Instance, DaemonConnectionPool>, | |
10 | } | |
11 | ||
12 | impl InstanceConnections { | |
13 | pub fn get_or_open(&mut self, instance: &Instance) -> Result<DaemonConnectionPool, Error> { | |
14 | if let Some(pool) = self.pools.get(instance) { | |
15 | Ok(pool.clone()) | |
16 | } else { | |
17 | let pool = DaemonConnectionPool::connect(instance.clone()).unwrap(); | |
18 | self.pools.insert(instance.clone(), pool.clone()); | |
19 | ||
20 | Ok(pool) | |
21 | } | |
22 | } | |
23 | } |
giterated-daemon/src/federation/mod.rs
@@ -0,0 +1 @@ | ||
1 | pub mod connections; |
giterated-daemon/src/lib.rs
@@ -5,6 +5,7 @@ use semver::{Version, VersionReq}; | ||
5 | 5 | pub mod authentication; |
6 | 6 | pub mod backend; |
7 | 7 | pub mod connection; |
8 | pub mod federation; | |
8 | 9 | pub mod message; |
9 | 10 | |
10 | 11 | #[macro_use] |
giterated-daemon/src/main.rs
@@ -4,6 +4,7 @@ use giterated_daemon::{ | ||
4 | 4 | authentication::AuthenticationTokenGranter, |
5 | 5 | backend::{git::GitBackend, user::UserAuth, RepositoryBackend, UserBackend}, |
6 | 6 | connection::{self, wrapper::connection_wrapper}, |
7 | federation::connections::InstanceConnections, | |
7 | 8 | }; |
8 | 9 | use giterated_models::model::instance::Instance; |
9 | 10 | use sqlx::{postgres::PgConnectOptions, ConnectOptions, PgPool}; |
@@ -25,6 +26,7 @@ async fn main() -> Result<(), Error> { | ||
25 | 26 | tracing_subscriber::fmt::init(); |
26 | 27 | let mut listener = TcpListener::bind("0.0.0.0:7270").await?; |
27 | 28 | let connections: Arc<Mutex<Connections>> = Arc::default(); |
29 | let instance_connections: Arc<Mutex<InstanceConnections>> = Arc::default(); | |
28 | 30 | let config: Table = { |
29 | 31 | let mut file = File::open("Giterated.toml").await?; |
30 | 32 | let mut text = String::new(); |
@@ -106,6 +108,7 @@ async fn main() -> Result<(), Error> { | ||
106 | 108 | token_granter.clone(), |
107 | 109 | address, |
108 | 110 | Instance::from_str("giterated.dev").unwrap(), |
111 | instance_connections.clone(), | |
109 | 112 | )), |
110 | 113 | }; |
111 | 114 |