JavaScript is disabled; refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Major connection refactor base

Type: Refactor

Amber - 2 years ago

parent: tbd commit: 8dcc111

src/connection.rs - 10013 bytes
Raw
1 pub mod authentication;
2 pub mod handshake;
3 pub mod repository;
4 pub mod user;
5 pub mod wrapper;
6
7 use std::{any::type_name, collections::HashMap, net::SocketAddr, str::FromStr, sync::Arc};
8
9 use anyhow::Error;
10 use futures_util::{stream::StreamExt, SinkExt};
11 use semver::Version;
12 use serde::{de::DeserializeOwned, Serialize};
13 use tokio::{
14 net::TcpStream,
15 sync::{
16 broadcast::{Receiver, Sender},
17 Mutex,
18 },
19 task::JoinHandle,
20 };
21 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
22
23 use crate::{
24 authentication::AuthenticationTokenGranter,
25 backend::{DiscoveryBackend, RepositoryBackend, UserBackend},
26 handshake::{HandshakeFinalize, HandshakeMessage, HandshakeResponse},
27 listener::Listeners,
28 messages::{
29 authentication::{
30 AuthenticationMessage, AuthenticationRequest, AuthenticationResponse,
31 TokenExtensionResponse,
32 },
33 repository::{
34 RepositoryMessage, RepositoryMessageKind, RepositoryRequest, RepositoryResponse,
35 },
36 user::{
37 UserMessage, UserMessageKind, UserMessageRequest, UserMessageResponse,
38 UserRepositoriesResponse,
39 },
40 ErrorMessage, MessageKind,
41 },
42 model::{
43 instance::{Instance, InstanceMeta},
44 repository::Repository,
45 user::User,
46 },
47 validate_version, version,
48 };
49
50 #[derive(Debug, thiserror::Error)]
51 pub enum ConnectionError {
52 #[error("connection error message {0}")]
53 ErrorMessage(#[from] ErrorMessage),
54 #[error("connection should close")]
55 Shutdown,
56 #[error("internal error {0}")]
57 InternalError(#[from] Error),
58 }
59
60 pub struct RawConnection {
61 pub task: JoinHandle<()>,
62 }
63
64 pub struct InstanceConnection {
65 pub instance: InstanceMeta,
66 pub sender: Sender<MessageKind>,
67 pub task: JoinHandle<()>,
68 }
69
70 /// Represents a connection which hasn't finished the handshake.
71 pub struct UnestablishedConnection {
72 pub socket: WebSocketStream<TcpStream>,
73 }
74
75 #[derive(Default)]
76 pub struct Connections {
77 pub connections: Vec<RawConnection>,
78 pub instance_connections: HashMap<Instance, InstanceConnection>,
79 }
80
81 pub async fn connection_worker(
82 mut socket: &mut WebSocketStream<TcpStream>,
83 handshaked: &mut bool,
84 listeners: &Arc<Mutex<Listeners>>,
85 connections: &Arc<Mutex<Connections>>,
86 backend: &Arc<Mutex<dyn RepositoryBackend + Send>>,
87 user_backend: &Arc<Mutex<dyn UserBackend + Send>>,
88 auth_granter: &Arc<Mutex<AuthenticationTokenGranter>>,
89 discovery_backend: &Arc<Mutex<dyn DiscoveryBackend + Send>>,
90 addr: &SocketAddr,
91 ) -> Result<(), ConnectionError> {
92 let this_instance = Instance {
93 url: String::from("giterated.dev"),
94 };
95
96 let message = socket
97 .next()
98 .await
99 .ok_or_else(|| ConnectionError::Shutdown)?
100 .map_err(|e| Error::from(e))?;
101
102 let payload = match message {
103 Message::Text(text) => text.into_bytes(),
104 Message::Binary(bytes) => bytes,
105 Message::Ping(_) => return Ok(()),
106 Message::Pong(_) => return Ok(()),
107 Message::Close(_) => {
108 info!("Closing connection with {}.", addr);
109
110 return Err(ConnectionError::Shutdown);
111 }
112 _ => unreachable!(),
113 };
114
115 let message = serde_json::from_slice::<MessageKind>(&payload).map_err(|e| Error::from(e))?;
116
117 if let MessageKind::Handshake(handshake) = message {
118 match handshake {
119 HandshakeMessage::Initiate(request) => {
120 unimplemented!()
121 }
122 HandshakeMessage::Response(response) => {
123 unimplemented!()
124 }
125 HandshakeMessage::Finalize(response) => {
126 unimplemented!()
127 }
128 }
129 }
130
131 if !*handshaked {
132 return Ok(());
133 }
134
135 if let MessageKind::Repository(repository) = &message {
136 if repository.target.instance != this_instance {
137 info!("Forwarding command to {}", repository.target.instance.url);
138 // We need to send this command to a different instance
139
140 let mut listener = send_and_get_listener(message, &listeners, &connections).await;
141
142 // Wait for response
143 while let Ok(message) = listener.recv().await {
144 if let MessageKind::Repository(RepositoryMessage {
145 command: RepositoryMessageKind::Response(_),
146 ..
147 }) = message
148 {
149 let _result = send(&mut socket, message).await;
150 }
151 }
152
153 return Ok(());
154 } else {
155 // This message is targeting this instance
156 match &repository.command {
157 RepositoryMessageKind::Request(request) => match request.clone() {
158 RepositoryRequest::CreateRepository(request) => {
159 unimplemented!();
160 }
161 RepositoryRequest::RepositoryFileInspect(request) => {
162 unimplemented!()
163 }
164 RepositoryRequest::RepositoryInfo(request) => {
165 unimplemented!()
166 }
167 RepositoryRequest::IssuesCount(request) => {
168 unimplemented!()
169 }
170 RepositoryRequest::IssueLabels(request) => {
171 unimplemented!()
172 }
173 RepositoryRequest::Issues(request) => {
174 unimplemented!();
175 }
176 },
177 RepositoryMessageKind::Response(_response) => {
178 unreachable!()
179 }
180 }
181 }
182 }
183
184 if let MessageKind::Authentication(authentication) = &message {
185 match authentication {
186 AuthenticationMessage::Request(request) => match request {
187 AuthenticationRequest::AuthenticationToken(token) => {
188 unimplemented!()
189 }
190 AuthenticationRequest::TokenExtension(request) => {
191 unimplemented!()
192 }
193 AuthenticationRequest::RegisterAccount(request) => {
194 unimplemented!()
195 }
196 },
197 AuthenticationMessage::Response(_) => unreachable!(),
198 }
199 }
200
201 if let MessageKind::Discovery(message) = &message {
202 let mut backend = discovery_backend.lock().await;
203 backend.try_handle(message).await?;
204
205 return Ok(());
206 }
207
208 if let MessageKind::User(message) = &message {
209 match &message.message {
210 UserMessageKind::Request(request) => match request {
211 UserMessageRequest::DisplayName(request) => {
212 unimplemented!()
213 }
214 UserMessageRequest::DisplayImage(request) => {
215 unimplemented!()
216 }
217 UserMessageRequest::Bio(request) => {
218 unimplemented!()
219 }
220 UserMessageRequest::Repositories(request) => {
221 unimplemented!()
222 }
223 },
224 UserMessageKind::Response(_) => unreachable!(),
225 }
226 }
227
228 Ok(())
229 }
230
231 async fn send_and_get_listener(
232 message: MessageKind,
233 listeners: &Arc<Mutex<Listeners>>,
234 connections: &Arc<Mutex<Connections>>,
235 ) -> Receiver<MessageKind> {
236 let (instance, user, repository): (Option<Instance>, Option<User>, Option<Repository>) =
237 match &message {
238 MessageKind::Handshake(_) => {
239 todo!()
240 }
241 MessageKind::Repository(repository) => (None, None, Some(repository.target.clone())),
242 MessageKind::Authentication(_) => todo!(),
243 MessageKind::Discovery(_) => todo!(),
244 MessageKind::User(user) => todo!(),
245 MessageKind::Error(_) => todo!(),
246 };
247
248 let target = match (&instance, &user, &repository) {
249 (Some(instance), _, _) => instance.clone(),
250 (_, Some(user), _) => user.instance.clone(),
251 (_, _, Some(repository)) => repository.instance.clone(),
252 _ => unreachable!(),
253 };
254
255 let mut listeners = listeners.lock().await;
256 let listener = listeners.add(instance, user, repository);
257 drop(listeners);
258
259 let connections = connections.lock().await;
260
261 if let Some(connection) = connections.instance_connections.get(&target) {
262 if let Err(_) = connection.sender.send(message) {
263 error!("Error sending message.");
264 }
265 } else {
266 error!("Unable to message {}, this is a bug.", target.url);
267
268 panic!();
269 }
270
271 drop(connections);
272
273 listener
274 }
275
276 async fn send<T: Serialize>(
277 socket: &mut WebSocketStream<TcpStream>,
278 message: T,
279 ) -> Result<(), Error> {
280 socket
281 .send(Message::Binary(serde_json::to_vec(&message)?))
282 .await?;
283
284 Ok(())
285 }
286
287 #[derive(Debug, thiserror::Error)]
288 #[error("handler did not handle")]
289 pub struct HandlerUnhandled;
290
/// Type-level description of a message handler.
///
/// In the function impls in this file, `A` is the handler's argument list as a
/// tuple and `R` is its return type. NOTE(review): `M` is not constrained by any
/// visible impl — confirm its intended meaning.
pub trait MessageHandling<A, M, R> {
    /// Name of the message type the handler accepts. For the function impls
    /// this is the type name of the first argument.
    fn message_type() -> &'static str;
}
294
295 impl<T1, F, M, R> MessageHandling<(T1,), M, R> for F
296 where
297 F: FnOnce(T1) -> R,
298 T1: Serialize + DeserializeOwned,
299 {
300 fn message_type() -> &'static str {
301 type_name::<T1>()
302 }
303 }
304
305 impl<T1, T2, F, M, R> MessageHandling<(T1, T2), M, R> for F
306 where
307 F: FnOnce(T1, T2) -> R,
308 T1: Serialize + DeserializeOwned,
309 {
310 fn message_type() -> &'static str {
311 type_name::<T1>()
312 }
313 }
314
315 impl<T1, T2, T3, F, M, R> MessageHandling<(T1, T2, T3), M, R> for F
316 where
317 F: FnOnce(T1, T2, T3) -> R,
318 T1: Serialize + DeserializeOwned,
319 {
320 fn message_type() -> &'static str {
321 type_name::<T1>()
322 }
323 }
324
325 impl<T1, T2, T3, T4, F, M, R> MessageHandling<(T1, T2, T3, T4), M, R> for F
326 where
327 F: FnOnce(T1, T2, T3, T4) -> R,
328 T1: Serialize + DeserializeOwned,
329 {
330 fn message_type() -> &'static str {
331 type_name::<T1>()
332 }
333 }
334