JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Add repository settings

Amber - ⁨2⁩ years ago

parent: tbd commit: ⁨f8eaf38

⁨giterated-daemon/src/connection/wrapper.rs⁩ - ⁨7455⁩ bytes
Raw
1 use std::{
2 net::SocketAddr,
3 sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc,
6 },
7 };
8
9 use anyhow::Error;
10 use futures_util::{SinkExt, StreamExt};
11 use giterated_models::{
12 messages::error::ConnectionError,
13 model::{authenticated::AuthenticatedPayload, instance::Instance},
14 };
15
16 use serde::Serialize;
17
18 use tokio::{net::TcpStream, sync::Mutex};
19 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
20
21 use crate::{
22 authentication::AuthenticationTokenGranter,
23 backend::{RepositoryBackend, SettingsBackend, UserBackend},
24 connection::forwarded::wrap_forwarded,
25 federation::connections::InstanceConnections,
26 keys::PublicKeyCache,
27 message::NetworkMessage,
28 };
29
30 use super::{
31 authentication::authentication_handle, handshake::handshake_handle,
32 repository::repository_handle, user::user_handle, Connections,
33 };
34
35 pub async fn connection_wrapper(
36 socket: WebSocketStream<TcpStream>,
37 connections: Arc<Mutex<Connections>>,
38 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
39 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
40 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
41 settings_backend: Arc<Mutex<dyn SettingsBackend>>,
42 addr: SocketAddr,
43 instance: impl ToOwned<Owned = Instance>,
44 instance_connections: Arc<Mutex<InstanceConnections>>,
45 ) {
46 let connection_state = ConnectionState {
47 socket: Arc::new(Mutex::new(socket)),
48 connections,
49 repository_backend,
50 user_backend,
51 auth_granter,
52 settings_backend,
53 addr,
54 instance: instance.to_owned(),
55 handshaked: Arc::new(AtomicBool::new(false)),
56 key_cache: Arc::default(),
57 };
58
59 let mut handshaked = false;
60
61 loop {
62 let mut socket = connection_state.socket.lock().await;
63 let message = socket.next().await;
64 drop(socket);
65
66 match message {
67 Some(Ok(message)) => {
68 let payload = match message {
69 Message::Binary(payload) => payload,
70 Message::Ping(_) => {
71 let mut socket = connection_state.socket.lock().await;
72 let _ = socket.send(Message::Pong(vec![])).await;
73 drop(socket);
74 continue;
75 }
76 Message::Close(_) => return,
77 _ => continue,
78 };
79
80 let message = NetworkMessage(payload.clone());
81
82 if !handshaked {
83 info!("im foo baring");
84 if handshake_handle(&message, &connection_state).await.is_ok() {
85 if connection_state.handshaked.load(Ordering::SeqCst) {
86 handshaked = true;
87 }
88 }
89 } else {
90 let raw = serde_json::from_slice::<AuthenticatedPayload>(&payload).unwrap();
91
92 if let Some(target_instance) = &raw.target_instance {
93 // Forward request
94 info!("Forwarding message to {}", target_instance.url);
95 let mut instance_connections = instance_connections.lock().await;
96 let pool = instance_connections.get_or_open(&target_instance).unwrap();
97 let pool_clone = pool.clone();
98 drop(pool);
99
100 let result = wrap_forwarded(&pool_clone, raw).await;
101
102 let mut socket = connection_state.socket.lock().await;
103 let _ = socket.send(result).await;
104
105 continue;
106 }
107
108 let message_type = &raw.message_type;
109
110 info!("Handling message with type: {}", message_type);
111
112 match authentication_handle(message_type, &message, &connection_state).await {
113 Err(e) => {
114 let _ = connection_state
115 .send_raw(ConnectionError(e.to_string()))
116 .await;
117 }
118 Ok(true) => continue,
119 Ok(false) => {}
120 }
121
122 match repository_handle(message_type, &message, &connection_state).await {
123 Err(e) => {
124 let _ = connection_state
125 .send_raw(ConnectionError(e.to_string()))
126 .await;
127 }
128 Ok(true) => continue,
129 Ok(false) => {}
130 }
131
132 match user_handle(message_type, &message, &connection_state).await {
133 Err(e) => {
134 let _ = connection_state
135 .send_raw(ConnectionError(e.to_string()))
136 .await;
137 }
138 Ok(true) => continue,
139 Ok(false) => {}
140 }
141
142 match authentication_handle(message_type, &message, &connection_state).await {
143 Err(e) => {
144 let _ = connection_state
145 .send_raw(ConnectionError(e.to_string()))
146 .await;
147 }
148 Ok(true) => continue,
149 Ok(false) => {}
150 }
151
152 error!(
153 "Message completely unhandled: {}",
154 std::str::from_utf8(&payload).unwrap()
155 );
156 }
157 }
158 Some(Err(e)) => {
159 error!("Closing connection for {:?} for {}", e, addr);
160 return;
161 }
162 _ => {
163 info!("Unhandled");
164 continue;
165 }
166 }
167 }
168 }
169
170 #[derive(Clone)]
171 pub struct ConnectionState {
172 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
173 pub connections: Arc<Mutex<Connections>>,
174 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
175 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
176 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
177 pub settings_backend: Arc<Mutex<dyn SettingsBackend>>,
178 pub addr: SocketAddr,
179 pub instance: Instance,
180 pub handshaked: Arc<AtomicBool>,
181 pub key_cache: Arc<Mutex<PublicKeyCache>>,
182 }
183
184 impl ConnectionState {
185 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
186 let payload = serde_json::to_string(&message)?;
187 info!("Sending payload: {}", &payload);
188 self.socket
189 .lock()
190 .await
191 .send(Message::Binary(payload.into_bytes()))
192 .await?;
193
194 Ok(())
195 }
196
197 pub async fn send_raw<T: Serialize>(&self, message: T) -> Result<(), Error> {
198 let payload = serde_json::to_string(&message)?;
199 info!("Sending payload: {}", &payload);
200 self.socket
201 .lock()
202 .await
203 .send(Message::Binary(payload.into_bytes()))
204 .await?;
205
206 Ok(())
207 }
208
209 pub async fn public_key(&self, instance: &Instance) -> Result<String, Error> {
210 let mut keys = self.key_cache.lock().await;
211 keys.get(instance).await
212 }
213 }
214