JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Implement Debug on all messages

Type: Fix

Amber - 2 years ago

parent: tbd commit: 249c88e

giterated-daemon/src/connection/wrapper.rs - 7396 bytes
Raw
1 use std::{
2 collections::HashMap,
3 net::SocketAddr,
4 sync::{
5 atomic::{AtomicBool, Ordering},
6 Arc,
7 },
8 };
9
10 use anyhow::Error;
11 use futures_util::{SinkExt, StreamExt};
12 use giterated_models::{
13 messages::error::ConnectionError,
14 model::{
15 authenticated::{Authenticated, AuthenticatedPayload},
16 instance::Instance,
17 },
18 };
19 use rsa::RsaPublicKey;
20 use serde::Serialize;
21 use serde_json::Value;
22 use tokio::{
23 net::TcpStream,
24 sync::{Mutex, RwLock},
25 };
26 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
27
28 use crate::{
29 authentication::AuthenticationTokenGranter,
30 backend::{RepositoryBackend, UserBackend},
31 connection::forwarded::wrap_forwarded,
32 federation::connections::InstanceConnections,
33 message::NetworkMessage,
34 };
35
36 use super::{
37 authentication::authentication_handle, handshake::handshake_handle,
38 repository::repository_handle, user::user_handle, Connections,
39 };
40
41 pub async fn connection_wrapper(
42 socket: WebSocketStream<TcpStream>,
43 connections: Arc<Mutex<Connections>>,
44 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
45 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
46 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
47 addr: SocketAddr,
48 instance: impl ToOwned<Owned = Instance>,
49 instance_connections: Arc<Mutex<InstanceConnections>>,
50 ) {
51 let connection_state = ConnectionState {
52 socket: Arc::new(Mutex::new(socket)),
53 connections,
54 repository_backend,
55 user_backend,
56 auth_granter,
57 addr,
58 instance: instance.to_owned(),
59 handshaked: Arc::new(AtomicBool::new(false)),
60 cached_keys: Arc::default(),
61 };
62
63 let mut handshaked = false;
64
65 loop {
66 let mut socket = connection_state.socket.lock().await;
67 let message = socket.next().await;
68 drop(socket);
69
70 match message {
71 Some(Ok(message)) => {
72 let payload = match message {
73 Message::Binary(payload) => payload,
74 Message::Ping(_) => {
75 let mut socket = connection_state.socket.lock().await;
76 let _ = socket.send(Message::Pong(vec![])).await;
77 drop(socket);
78 continue;
79 }
80 Message::Close(_) => return,
81 _ => continue,
82 };
83
84 info!(
85 "Received payload: {}",
86 std::str::from_utf8(&payload).unwrap()
87 );
88
89 let message = NetworkMessage(payload.clone());
90
91 if !handshaked {
92 info!("im foo baring");
93 if handshake_handle(&message, &connection_state).await.is_ok() {
94 if connection_state.handshaked.load(Ordering::SeqCst) {
95 handshaked = true;
96 }
97 }
98 } else {
99 let raw = serde_json::from_slice::<AuthenticatedPayload>(&payload).unwrap();
100
101 if let Some(target_instance) = &raw.target_instance {
102 // Forward request
103 info!("Forwarding message to {}", target_instance.url);
104 let mut instance_connections = instance_connections.lock().await;
105 let pool = instance_connections.get_or_open(&target_instance).unwrap();
106 let pool_clone = pool.clone();
107 drop(pool);
108
109 let result = wrap_forwarded(&pool_clone, raw).await;
110
111 let mut socket = connection_state.socket.lock().await;
112 let _ = socket.send(result).await;
113
114 continue;
115 }
116
117 let message_type = &raw.message_type;
118
119 info!("Handling message with type: {}", message_type);
120
121 match authentication_handle(message_type, &message, &connection_state).await {
122 Err(e) => {
123 let _ = connection_state
124 .send_raw(ConnectionError(e.to_string()))
125 .await;
126 }
127 Ok(true) => continue,
128 Ok(false) => {}
129 }
130
131 match repository_handle(message_type, &message, &connection_state).await {
132 Err(e) => {
133 let _ = connection_state
134 .send_raw(ConnectionError(e.to_string()))
135 .await;
136 }
137 Ok(true) => continue,
138 Ok(false) => {}
139 }
140
141 match user_handle(message_type, &message, &connection_state).await {
142 Err(e) => {
143 let _ = connection_state
144 .send_raw(ConnectionError(e.to_string()))
145 .await;
146 }
147 Ok(true) => continue,
148 Ok(false) => {}
149 }
150
151 match authentication_handle(message_type, &message, &connection_state).await {
152 Err(e) => {
153 let _ = connection_state
154 .send_raw(ConnectionError(e.to_string()))
155 .await;
156 }
157 Ok(true) => continue,
158 Ok(false) => {}
159 }
160
161 error!(
162 "Message completely unhandled: {}",
163 std::str::from_utf8(&payload).unwrap()
164 );
165 }
166 }
167 Some(Err(e)) => {
168 error!("Closing connection for {:?} for {}", e, addr);
169 return;
170 }
171 _ => {
172 info!("Unhandled");
173 continue;
174 }
175 }
176 }
177 }
178
179 #[derive(Clone)]
180 pub struct ConnectionState {
181 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
182 pub connections: Arc<Mutex<Connections>>,
183 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
184 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
185 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
186 pub addr: SocketAddr,
187 pub instance: Instance,
188 pub handshaked: Arc<AtomicBool>,
189 pub cached_keys: Arc<RwLock<HashMap<Instance, RsaPublicKey>>>,
190 }
191
192 impl ConnectionState {
193 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
194 let payload = serde_json::to_string(&message)?;
195 info!("Sending payload: {}", &payload);
196 self.socket
197 .lock()
198 .await
199 .send(Message::Binary(payload.into_bytes()))
200 .await?;
201
202 Ok(())
203 }
204
205 pub async fn send_raw<T: Serialize>(&self, message: T) -> Result<(), Error> {
206 let payload = serde_json::to_string(&message)?;
207 info!("Sending payload: {}", &payload);
208 self.socket
209 .lock()
210 .await
211 .send(Message::Binary(payload.into_bytes()))
212 .await?;
213
214 Ok(())
215 }
216 }
217