JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Change forwarding

Amber - ⁨2⁩ years ago

parent: tbd commit: ⁨2556801

⁨giterated-daemon/src/connection/wrapper.rs⁩ - ⁨7653⁩ bytes
Raw
1 use std::{
2 net::SocketAddr,
3 sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc,
6 },
7 };
8
9 use anyhow::Error;
10 use futures_util::{SinkExt, StreamExt};
11 use giterated_models::{
12 messages::error::ConnectionError,
13 model::{authenticated::AuthenticatedPayload, instance::Instance},
14 };
15
16 use serde::Serialize;
17
18 use tokio::{net::TcpStream, sync::Mutex};
19 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
20 use toml::Table;
21
22 use crate::{
23 authentication::AuthenticationTokenGranter,
24 backend::{RepositoryBackend, SettingsBackend, UserBackend},
25 connection::forwarded::wrap_forwarded,
26 federation::connections::InstanceConnections,
27 keys::PublicKeyCache,
28 message::NetworkMessage,
29 };
30
31 use super::{
32 authentication::authentication_handle, handshake::handshake_handle,
33 repository::repository_handle, user::user_handle, Connections,
34 };
35
36 pub async fn connection_wrapper(
37 socket: WebSocketStream<TcpStream>,
38 connections: Arc<Mutex<Connections>>,
39 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
40 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
41 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
42 settings_backend: Arc<Mutex<dyn SettingsBackend>>,
43 addr: SocketAddr,
44 instance: impl ToOwned<Owned = Instance>,
45 instance_connections: Arc<Mutex<InstanceConnections>>,
46 config: Table,
47 ) {
48 let connection_state = ConnectionState {
49 socket: Arc::new(Mutex::new(socket)),
50 connections,
51 repository_backend,
52 user_backend,
53 auth_granter,
54 settings_backend,
55 addr,
56 instance: instance.to_owned(),
57 handshaked: Arc::new(AtomicBool::new(false)),
58 key_cache: Arc::default(),
59 instance_connections: instance_connections.clone(),
60 config,
61 };
62
63 let mut handshaked = false;
64
65 loop {
66 let mut socket = connection_state.socket.lock().await;
67 let message = socket.next().await;
68 drop(socket);
69
70 match message {
71 Some(Ok(message)) => {
72 let payload = match message {
73 Message::Binary(payload) => payload,
74 Message::Ping(_) => {
75 let mut socket = connection_state.socket.lock().await;
76 let _ = socket.send(Message::Pong(vec![])).await;
77 drop(socket);
78 continue;
79 }
80 Message::Close(_) => return,
81 _ => continue,
82 };
83
84 let message = NetworkMessage(payload.clone());
85
86 if !handshaked {
87 info!("im foo baring");
88 if handshake_handle(&message, &connection_state).await.is_ok() {
89 if connection_state.handshaked.load(Ordering::SeqCst) {
90 handshaked = true;
91 }
92 }
93 } else {
94 let raw = serde_json::from_slice::<AuthenticatedPayload>(&payload).unwrap();
95
96 if let Some(target_instance) = &raw.target_instance {
97 // Forward request
98 info!("Forwarding message to {}", target_instance.url);
99 let mut instance_connections = instance_connections.lock().await;
100 let pool = instance_connections.get_or_open(&target_instance).unwrap();
101 let pool_clone = pool.clone();
102 drop(pool);
103
104 let result = wrap_forwarded(&pool_clone, raw).await;
105
106 let mut socket = connection_state.socket.lock().await;
107 let _ = socket.send(result).await;
108
109 continue;
110 }
111
112 let message_type = &raw.message_type;
113
114 info!("Handling message with type: {}", message_type);
115
116 match authentication_handle(message_type, &message, &connection_state).await {
117 Err(e) => {
118 let _ = connection_state
119 .send_raw(ConnectionError(e.to_string()))
120 .await;
121 }
122 Ok(true) => continue,
123 Ok(false) => {}
124 }
125
126 match repository_handle(message_type, &message, &connection_state).await {
127 Err(e) => {
128 let _ = connection_state
129 .send_raw(ConnectionError(e.to_string()))
130 .await;
131 }
132 Ok(true) => continue,
133 Ok(false) => {}
134 }
135
136 match user_handle(message_type, &message, &connection_state).await {
137 Err(e) => {
138 let _ = connection_state
139 .send_raw(ConnectionError(e.to_string()))
140 .await;
141 }
142 Ok(true) => continue,
143 Ok(false) => {}
144 }
145
146 match authentication_handle(message_type, &message, &connection_state).await {
147 Err(e) => {
148 let _ = connection_state
149 .send_raw(ConnectionError(e.to_string()))
150 .await;
151 }
152 Ok(true) => continue,
153 Ok(false) => {}
154 }
155
156 error!(
157 "Message completely unhandled: {}",
158 std::str::from_utf8(&payload).unwrap()
159 );
160 }
161 }
162 Some(Err(e)) => {
163 error!("Closing connection for {:?} for {}", e, addr);
164 return;
165 }
166 _ => {
167 info!("Unhandled");
168 continue;
169 }
170 }
171 }
172 }
173
174 #[derive(Clone)]
175 pub struct ConnectionState {
176 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
177 pub connections: Arc<Mutex<Connections>>,
178 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
179 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
180 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
181 pub settings_backend: Arc<Mutex<dyn SettingsBackend>>,
182 pub addr: SocketAddr,
183 pub instance: Instance,
184 pub handshaked: Arc<AtomicBool>,
185 pub key_cache: Arc<Mutex<PublicKeyCache>>,
186 pub instance_connections: Arc<Mutex<InstanceConnections>>,
187 pub config: Table,
188 }
189
190 impl ConnectionState {
191 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
192 let payload = serde_json::to_string(&message)?;
193 info!("Sending payload: {}", &payload);
194 self.socket
195 .lock()
196 .await
197 .send(Message::Binary(payload.into_bytes()))
198 .await?;
199
200 Ok(())
201 }
202
203 pub async fn send_raw<T: Serialize>(&self, message: T) -> Result<(), Error> {
204 let payload = serde_json::to_string(&message)?;
205 info!("Sending payload: {}", &payload);
206 self.socket
207 .lock()
208 .await
209 .send(Message::Binary(payload.into_bytes()))
210 .await?;
211
212 Ok(())
213 }
214
215 pub async fn public_key(&self, instance: &Instance) -> Result<String, Error> {
216 let mut keys = self.key_cache.lock().await;
217 keys.get(instance).await
218 }
219 }
220