JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Add more aggressive key caching

Amber - 2 years ago

parent: tbd commit: 5bc92ad

⁨giterated-daemon/src/connection/wrapper.rs⁩ - ⁨7429⁩ bytes
Raw
1 use std::{
2 collections::HashMap,
3 net::SocketAddr,
4 sync::{
5 atomic::{AtomicBool, Ordering},
6 Arc,
7 },
8 };
9
10 use anyhow::Error;
11 use futures_util::{SinkExt, StreamExt};
12 use giterated_models::{
13 messages::error::ConnectionError,
14 model::{
15 authenticated::{Authenticated, AuthenticatedPayload},
16 instance::Instance,
17 },
18 };
19 use rsa::RsaPublicKey;
20 use serde::Serialize;
21 use serde_json::Value;
22 use tokio::{
23 net::TcpStream,
24 sync::{Mutex, RwLock},
25 };
26 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
27
28 use crate::{
29 authentication::AuthenticationTokenGranter,
30 backend::{RepositoryBackend, UserBackend},
31 connection::forwarded::wrap_forwarded,
32 federation::connections::InstanceConnections,
33 keys::PublicKeyCache,
34 message::NetworkMessage,
35 };
36
37 use super::{
38 authentication::authentication_handle, handshake::handshake_handle,
39 repository::repository_handle, user::user_handle, Connections,
40 };
41
42 pub async fn connection_wrapper(
43 socket: WebSocketStream<TcpStream>,
44 connections: Arc<Mutex<Connections>>,
45 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
46 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
47 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
48 addr: SocketAddr,
49 instance: impl ToOwned<Owned = Instance>,
50 instance_connections: Arc<Mutex<InstanceConnections>>,
51 ) {
52 let connection_state = ConnectionState {
53 socket: Arc::new(Mutex::new(socket)),
54 connections,
55 repository_backend,
56 user_backend,
57 auth_granter,
58 addr,
59 instance: instance.to_owned(),
60 handshaked: Arc::new(AtomicBool::new(false)),
61 key_cache: Arc::default(),
62 };
63
64 let mut handshaked = false;
65
66 loop {
67 let mut socket = connection_state.socket.lock().await;
68 let message = socket.next().await;
69 drop(socket);
70
71 match message {
72 Some(Ok(message)) => {
73 let payload = match message {
74 Message::Binary(payload) => payload,
75 Message::Ping(_) => {
76 let mut socket = connection_state.socket.lock().await;
77 let _ = socket.send(Message::Pong(vec![])).await;
78 drop(socket);
79 continue;
80 }
81 Message::Close(_) => return,
82 _ => continue,
83 };
84
85 let message = NetworkMessage(payload.clone());
86
87 if !handshaked {
88 info!("im foo baring");
89 if handshake_handle(&message, &connection_state).await.is_ok() {
90 if connection_state.handshaked.load(Ordering::SeqCst) {
91 handshaked = true;
92 }
93 }
94 } else {
95 let raw = serde_json::from_slice::<AuthenticatedPayload>(&payload).unwrap();
96
97 if let Some(target_instance) = &raw.target_instance {
98 // Forward request
99 info!("Forwarding message to {}", target_instance.url);
100 let mut instance_connections = instance_connections.lock().await;
101 let pool = instance_connections.get_or_open(&target_instance).unwrap();
102 let pool_clone = pool.clone();
103 drop(pool);
104
105 let result = wrap_forwarded(&pool_clone, raw).await;
106
107 let mut socket = connection_state.socket.lock().await;
108 let _ = socket.send(result).await;
109
110 continue;
111 }
112
113 let message_type = &raw.message_type;
114
115 info!("Handling message with type: {}", message_type);
116
117 match authentication_handle(message_type, &message, &connection_state).await {
118 Err(e) => {
119 let _ = connection_state
120 .send_raw(ConnectionError(e.to_string()))
121 .await;
122 }
123 Ok(true) => continue,
124 Ok(false) => {}
125 }
126
127 match repository_handle(message_type, &message, &connection_state).await {
128 Err(e) => {
129 let _ = connection_state
130 .send_raw(ConnectionError(e.to_string()))
131 .await;
132 }
133 Ok(true) => continue,
134 Ok(false) => {}
135 }
136
137 match user_handle(message_type, &message, &connection_state).await {
138 Err(e) => {
139 let _ = connection_state
140 .send_raw(ConnectionError(e.to_string()))
141 .await;
142 }
143 Ok(true) => continue,
144 Ok(false) => {}
145 }
146
147 match authentication_handle(message_type, &message, &connection_state).await {
148 Err(e) => {
149 let _ = connection_state
150 .send_raw(ConnectionError(e.to_string()))
151 .await;
152 }
153 Ok(true) => continue,
154 Ok(false) => {}
155 }
156
157 error!(
158 "Message completely unhandled: {}",
159 std::str::from_utf8(&payload).unwrap()
160 );
161 }
162 }
163 Some(Err(e)) => {
164 error!("Closing connection for {:?} for {}", e, addr);
165 return;
166 }
167 _ => {
168 info!("Unhandled");
169 continue;
170 }
171 }
172 }
173 }
174
175 #[derive(Clone)]
176 pub struct ConnectionState {
177 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
178 pub connections: Arc<Mutex<Connections>>,
179 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
180 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
181 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
182 pub addr: SocketAddr,
183 pub instance: Instance,
184 pub handshaked: Arc<AtomicBool>,
185 pub key_cache: Arc<Mutex<PublicKeyCache>>,
186 }
187
188 impl ConnectionState {
189 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
190 let payload = serde_json::to_string(&message)?;
191 info!("Sending payload: {}", &payload);
192 self.socket
193 .lock()
194 .await
195 .send(Message::Binary(payload.into_bytes()))
196 .await?;
197
198 Ok(())
199 }
200
201 pub async fn send_raw<T: Serialize>(&self, message: T) -> Result<(), Error> {
202 let payload = serde_json::to_string(&message)?;
203 info!("Sending payload: {}", &payload);
204 self.socket
205 .lock()
206 .await
207 .send(Message::Binary(payload.into_bytes()))
208 .await?;
209
210 Ok(())
211 }
212
213 pub async fn public_key(&self, instance: &Instance) -> Result<String, Error> {
214 let mut keys = self.key_cache.lock().await;
215 keys.get(instance).await
216 }
217 }
218