JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Add message forwarding

Amber - 2 years ago

parent: tbd commit: e4fa992

giterated-daemon/src/connection/wrapper.rs - 6377 bytes
Raw
1 use std::{
2 collections::HashMap,
3 net::SocketAddr,
4 sync::{
5 atomic::{AtomicBool, Ordering},
6 Arc,
7 },
8 };
9
10 use anyhow::Error;
11 use futures_util::{SinkExt, StreamExt};
12 use giterated_models::{
13 messages::error::ConnectionError,
14 model::{authenticated::Authenticated, instance::Instance},
15 };
16 use rsa::RsaPublicKey;
17 use serde::Serialize;
18 use serde_json::Value;
19 use tokio::{
20 net::TcpStream,
21 sync::{Mutex, RwLock},
22 };
23 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
24
25 use crate::{
26 authentication::AuthenticationTokenGranter,
27 backend::{RepositoryBackend, UserBackend},
28 connection::forwarded::wrap_forwarded,
29 federation::connections::InstanceConnections,
30 message::NetworkMessage,
31 };
32
33 use super::{
34 authentication::authentication_handle, handshake::handshake_handle,
35 repository::repository_handle, user::user_handle, Connections,
36 };
37
38 pub async fn connection_wrapper(
39 socket: WebSocketStream<TcpStream>,
40 connections: Arc<Mutex<Connections>>,
41 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
42 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
43 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
44 addr: SocketAddr,
45 instance: impl ToOwned<Owned = Instance>,
46 instance_connections: Arc<Mutex<InstanceConnections>>,
47 ) {
48 let connection_state = ConnectionState {
49 socket: Arc::new(Mutex::new(socket)),
50 connections,
51 repository_backend,
52 user_backend,
53 auth_granter,
54 addr,
55 instance: instance.to_owned(),
56 handshaked: Arc::new(AtomicBool::new(false)),
57 cached_keys: Arc::default(),
58 };
59
60 let mut handshaked = false;
61
62 loop {
63 let mut socket = connection_state.socket.lock().await;
64 let message = socket.next().await;
65 drop(socket);
66
67 match message {
68 Some(Ok(message)) => {
69 let payload = match message {
70 Message::Binary(payload) => payload,
71 Message::Ping(_) => {
72 let mut socket = connection_state.socket.lock().await;
73 let _ = socket.send(Message::Pong(vec![])).await;
74 drop(socket);
75 continue;
76 }
77 Message::Close(_) => return,
78 _ => continue,
79 };
80
81 let message = NetworkMessage(payload.clone());
82
83 if !handshaked {
84 if handshake_handle(&message, &connection_state).await.is_ok() {
85 if connection_state.handshaked.load(Ordering::SeqCst) {
86 handshaked = true;
87 }
88 }
89 } else {
90 let raw = serde_json::from_slice::<Authenticated<Value>>(&payload).unwrap();
91
92 if let Some(target_instance) = &raw.target_instance {
93 // Forward request
94 let mut instance_connections = instance_connections.lock().await;
95 let pool = instance_connections.get_or_open(&target_instance).unwrap();
96 let pool_clone = pool.clone();
97 drop(pool);
98
99 let result = wrap_forwarded(&pool_clone, raw).await;
100
101 let mut socket = connection_state.socket.lock().await;
102 let _ = socket.send(result).await;
103
104 continue;
105 }
106
107 let message_type = &raw.message_type;
108
109 match authentication_handle(message_type, &message, &connection_state).await {
110 Err(e) => {
111 let _ = connection_state.send(ConnectionError(e.to_string())).await;
112 }
113 Ok(true) => continue,
114 Ok(false) => {}
115 }
116
117 match repository_handle(message_type, &message, &connection_state).await {
118 Err(e) => {
119 let _ = connection_state.send(ConnectionError(e.to_string())).await;
120 }
121 Ok(true) => continue,
122 Ok(false) => {}
123 }
124
125 match user_handle(message_type, &message, &connection_state).await {
126 Err(e) => {
127 let _ = connection_state.send(ConnectionError(e.to_string())).await;
128 }
129 Ok(true) => continue,
130 Ok(false) => {}
131 }
132
133 match authentication_handle(message_type, &message, &connection_state).await {
134 Err(e) => {
135 let _ = connection_state.send(ConnectionError(e.to_string())).await;
136 }
137 Ok(true) => continue,
138 Ok(false) => {}
139 }
140
141 error!(
142 "Message completely unhandled: {}",
143 std::str::from_utf8(&payload).unwrap()
144 );
145 }
146 }
147 Some(Err(e)) => {
148 error!("Closing connection for {:?} for {}", e, addr);
149 return;
150 }
151 _ => {
152 info!("Unhandled");
153 continue;
154 }
155 }
156 }
157 }
158
159 #[derive(Clone)]
160 pub struct ConnectionState {
161 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
162 pub connections: Arc<Mutex<Connections>>,
163 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
164 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
165 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
166 pub addr: SocketAddr,
167 pub instance: Instance,
168 pub handshaked: Arc<AtomicBool>,
169 pub cached_keys: Arc<RwLock<HashMap<Instance, RsaPublicKey>>>,
170 }
171
172 impl ConnectionState {
173 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
174 let payload = serde_json::to_string(&message)?;
175 info!("Sending payload: {}", &payload);
176 self.socket
177 .lock()
178 .await
179 .send(Message::Binary(payload.into_bytes()))
180 .await?;
181
182 Ok(())
183 }
184 }
185