1 |
pub mod authentication;
|
2 |
pub mod handshake;
|
3 |
pub mod repository;
|
4 |
pub mod user;
|
5 |
pub mod wrapper;
|
6 |
|
7 |
use std::{any::type_name, collections::HashMap, net::SocketAddr, str::FromStr, sync::Arc};
|
8 |
|
9 |
use anyhow::Error;
|
10 |
use futures_util::{stream::StreamExt, SinkExt};
|
11 |
use semver::Version;
|
12 |
use serde::{de::DeserializeOwned, Serialize};
|
13 |
use tokio::{
|
14 |
net::TcpStream,
|
15 |
sync::{
|
16 |
broadcast::{Receiver, Sender},
|
17 |
Mutex,
|
18 |
},
|
19 |
task::JoinHandle,
|
20 |
};
|
21 |
use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
|
22 |
|
23 |
use crate::{
|
24 |
authentication::AuthenticationTokenGranter,
|
25 |
backend::{DiscoveryBackend, RepositoryBackend, UserBackend},
|
26 |
handshake::{HandshakeFinalize, HandshakeMessage, HandshakeResponse},
|
27 |
listener::Listeners,
|
28 |
messages::{
|
29 |
authentication::{
|
30 |
AuthenticationMessage, AuthenticationRequest, AuthenticationResponse,
|
31 |
TokenExtensionResponse,
|
32 |
},
|
33 |
repository::{
|
34 |
RepositoryMessage, RepositoryMessageKind, RepositoryRequest, RepositoryResponse,
|
35 |
},
|
36 |
user::{
|
37 |
UserMessage, UserMessageKind, UserMessageRequest, UserMessageResponse,
|
38 |
UserRepositoriesResponse,
|
39 |
},
|
40 |
ErrorMessage, MessageKind,
|
41 |
},
|
42 |
model::{
|
43 |
instance::{Instance, InstanceMeta},
|
44 |
repository::Repository,
|
45 |
user::User,
|
46 |
},
|
47 |
validate_version, version,
|
48 |
};
|
49 |
|
50 |
/// Errors that end or abort the handling of a connection; this is the error type
/// returned by `connection_worker`.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// Wraps an [`ErrorMessage`]; `?` converts one into this variant via `From`.
    #[error("connection error message {0}")]
    ErrorMessage(#[from] ErrorMessage),
    /// The connection should be closed, e.g. the message stream ended or the peer
    /// sent a Close frame.
    #[error("connection should close")]
    Shutdown,
    /// Any other failure carried as an `anyhow::Error` (socket, JSON, ...); `?`
    /// converts one into this variant via `From`.
    #[error("internal error {0}")]
    InternalError(#[from] Error),
}
|
59 |
|
60 |
pub struct RawConnection {
|
61 |
pub task: JoinHandle<()>,
|
62 |
}
|
63 |
|
64 |
/// A connection to another instance, used as the routing target when messages are
/// forwarded to that instance.
pub struct InstanceConnection {
    /// Metadata about the remote instance.
    pub instance: InstanceMeta,
    /// Broadcast channel that forwarded messages are sent into
    /// (presumably drained by `task` and written to the remote — confirm).
    pub sender: Sender<MessageKind>,
    /// Handle to the spawned task associated with this connection.
    pub task: JoinHandle<()>,
}
|
69 |
|
70 |
|
71 |
pub struct UnestablishedConnection {
|
72 |
pub socket: WebSocketStream<TcpStream>,
|
73 |
}
|
74 |
|
75 |
/// Registry of open connections; `Default` yields empty collections.
#[derive(Default)]
pub struct Connections {
    /// Connections tracked only by their task handle.
    pub connections: Vec<RawConnection>,
    /// Connections to other instances, keyed by instance; looked up when a message
    /// must be forwarded to a given instance.
    pub instance_connections: HashMap<Instance, InstanceConnection>,
}
|
80 |
|
81 |
pub async fn connection_worker(
|
82 |
mut socket: &mut WebSocketStream<TcpStream>,
|
83 |
handshaked: &mut bool,
|
84 |
listeners: &Arc<Mutex<Listeners>>,
|
85 |
connections: &Arc<Mutex<Connections>>,
|
86 |
backend: &Arc<Mutex<dyn RepositoryBackend + Send>>,
|
87 |
user_backend: &Arc<Mutex<dyn UserBackend + Send>>,
|
88 |
auth_granter: &Arc<Mutex<AuthenticationTokenGranter>>,
|
89 |
discovery_backend: &Arc<Mutex<dyn DiscoveryBackend + Send>>,
|
90 |
addr: &SocketAddr,
|
91 |
) -> Result<(), ConnectionError> {
|
92 |
let this_instance = Instance {
|
93 |
url: String::from("giterated.dev"),
|
94 |
};
|
95 |
|
96 |
let message = socket
|
97 |
.next()
|
98 |
.await
|
99 |
.ok_or_else(|| ConnectionError::Shutdown)?
|
100 |
.map_err(|e| Error::from(e))?;
|
101 |
|
102 |
let payload = match message {
|
103 |
Message::Text(text) => text.into_bytes(),
|
104 |
Message::Binary(bytes) => bytes,
|
105 |
Message::Ping(_) => return Ok(()),
|
106 |
Message::Pong(_) => return Ok(()),
|
107 |
Message::Close(_) => {
|
108 |
info!("Closing connection with {}.", addr);
|
109 |
|
110 |
return Err(ConnectionError::Shutdown);
|
111 |
}
|
112 |
_ => unreachable!(),
|
113 |
};
|
114 |
|
115 |
let message = serde_json::from_slice::<MessageKind>(&payload).map_err(|e| Error::from(e))?;
|
116 |
|
117 |
if let MessageKind::Handshake(handshake) = message {
|
118 |
match handshake {
|
119 |
HandshakeMessage::Initiate(request) => {
|
120 |
unimplemented!()
|
121 |
}
|
122 |
HandshakeMessage::Response(response) => {
|
123 |
unimplemented!()
|
124 |
}
|
125 |
HandshakeMessage::Finalize(response) => {
|
126 |
unimplemented!()
|
127 |
}
|
128 |
}
|
129 |
}
|
130 |
|
131 |
if !*handshaked {
|
132 |
return Ok(());
|
133 |
}
|
134 |
|
135 |
if let MessageKind::Repository(repository) = &message {
|
136 |
if repository.target.instance != this_instance {
|
137 |
info!("Forwarding command to {}", repository.target.instance.url);
|
138 |
|
139 |
|
140 |
let mut listener = send_and_get_listener(message, &listeners, &connections).await;
|
141 |
|
142 |
|
143 |
while let Ok(message) = listener.recv().await {
|
144 |
if let MessageKind::Repository(RepositoryMessage {
|
145 |
command: RepositoryMessageKind::Response(_),
|
146 |
..
|
147 |
}) = message
|
148 |
{
|
149 |
let _result = send(&mut socket, message).await;
|
150 |
}
|
151 |
}
|
152 |
|
153 |
return Ok(());
|
154 |
} else {
|
155 |
|
156 |
match &repository.command {
|
157 |
RepositoryMessageKind::Request(request) => match request.clone() {
|
158 |
RepositoryRequest::CreateRepository(request) => {
|
159 |
unimplemented!();
|
160 |
}
|
161 |
RepositoryRequest::RepositoryFileInspect(request) => {
|
162 |
unimplemented!()
|
163 |
}
|
164 |
RepositoryRequest::RepositoryInfo(request) => {
|
165 |
unimplemented!()
|
166 |
}
|
167 |
RepositoryRequest::IssuesCount(request) => {
|
168 |
unimplemented!()
|
169 |
}
|
170 |
RepositoryRequest::IssueLabels(request) => {
|
171 |
unimplemented!()
|
172 |
}
|
173 |
RepositoryRequest::Issues(request) => {
|
174 |
unimplemented!();
|
175 |
}
|
176 |
},
|
177 |
RepositoryMessageKind::Response(_response) => {
|
178 |
unreachable!()
|
179 |
}
|
180 |
}
|
181 |
}
|
182 |
}
|
183 |
|
184 |
if let MessageKind::Authentication(authentication) = &message {
|
185 |
match authentication {
|
186 |
AuthenticationMessage::Request(request) => match request {
|
187 |
AuthenticationRequest::AuthenticationToken(token) => {
|
188 |
unimplemented!()
|
189 |
}
|
190 |
AuthenticationRequest::TokenExtension(request) => {
|
191 |
unimplemented!()
|
192 |
}
|
193 |
AuthenticationRequest::RegisterAccount(request) => {
|
194 |
unimplemented!()
|
195 |
}
|
196 |
},
|
197 |
AuthenticationMessage::Response(_) => unreachable!(),
|
198 |
}
|
199 |
}
|
200 |
|
201 |
if let MessageKind::Discovery(message) = &message {
|
202 |
let mut backend = discovery_backend.lock().await;
|
203 |
backend.try_handle(message).await?;
|
204 |
|
205 |
return Ok(());
|
206 |
}
|
207 |
|
208 |
if let MessageKind::User(message) = &message {
|
209 |
match &message.message {
|
210 |
UserMessageKind::Request(request) => match request {
|
211 |
UserMessageRequest::DisplayName(request) => {
|
212 |
unimplemented!()
|
213 |
}
|
214 |
UserMessageRequest::DisplayImage(request) => {
|
215 |
unimplemented!()
|
216 |
}
|
217 |
UserMessageRequest::Bio(request) => {
|
218 |
unimplemented!()
|
219 |
}
|
220 |
UserMessageRequest::Repositories(request) => {
|
221 |
unimplemented!()
|
222 |
}
|
223 |
},
|
224 |
UserMessageKind::Response(_) => unreachable!(),
|
225 |
}
|
226 |
}
|
227 |
|
228 |
Ok(())
|
229 |
}
|
230 |
|
231 |
async fn send_and_get_listener(
|
232 |
message: MessageKind,
|
233 |
listeners: &Arc<Mutex<Listeners>>,
|
234 |
connections: &Arc<Mutex<Connections>>,
|
235 |
) -> Receiver<MessageKind> {
|
236 |
let (instance, user, repository): (Option<Instance>, Option<User>, Option<Repository>) =
|
237 |
match &message {
|
238 |
MessageKind::Handshake(_) => {
|
239 |
todo!()
|
240 |
}
|
241 |
MessageKind::Repository(repository) => (None, None, Some(repository.target.clone())),
|
242 |
MessageKind::Authentication(_) => todo!(),
|
243 |
MessageKind::Discovery(_) => todo!(),
|
244 |
MessageKind::User(user) => todo!(),
|
245 |
MessageKind::Error(_) => todo!(),
|
246 |
};
|
247 |
|
248 |
let target = match (&instance, &user, &repository) {
|
249 |
(Some(instance), _, _) => instance.clone(),
|
250 |
(_, Some(user), _) => user.instance.clone(),
|
251 |
(_, _, Some(repository)) => repository.instance.clone(),
|
252 |
_ => unreachable!(),
|
253 |
};
|
254 |
|
255 |
let mut listeners = listeners.lock().await;
|
256 |
let listener = listeners.add(instance, user, repository);
|
257 |
drop(listeners);
|
258 |
|
259 |
let connections = connections.lock().await;
|
260 |
|
261 |
if let Some(connection) = connections.instance_connections.get(&target) {
|
262 |
if let Err(_) = connection.sender.send(message) {
|
263 |
error!("Error sending message.");
|
264 |
}
|
265 |
} else {
|
266 |
error!("Unable to message {}, this is a bug.", target.url);
|
267 |
|
268 |
panic!();
|
269 |
}
|
270 |
|
271 |
drop(connections);
|
272 |
|
273 |
listener
|
274 |
}
|
275 |
|
276 |
async fn send<T: Serialize>(
|
277 |
socket: &mut WebSocketStream<TcpStream>,
|
278 |
message: T,
|
279 |
) -> Result<(), Error> {
|
280 |
socket
|
281 |
.send(Message::Binary(serde_json::to_vec(&message)?))
|
282 |
.await?;
|
283 |
|
284 |
Ok(())
|
285 |
}
|
286 |
|
287 |
/// Unit error whose `Display` text is "handler did not handle".
///
/// NOTE(review): nothing in this file produces it; presumably returned by a handler
/// that declined the message it was given — confirm against callers.
#[derive(Debug, thiserror::Error)]
#[error("handler did not handle")]
pub struct HandlerUnhandled;
|
290 |
|
291 |
pub trait MessageHandling<A, M, R> {
|
292 |
fn message_type() -> &'static str;
|
293 |
}
|
294 |
|
295 |
impl<T1, F, M, R> MessageHandling<(T1,), M, R> for F
|
296 |
where
|
297 |
F: FnOnce(T1) -> R,
|
298 |
T1: Serialize + DeserializeOwned,
|
299 |
{
|
300 |
fn message_type() -> &'static str {
|
301 |
type_name::<T1>()
|
302 |
}
|
303 |
}
|
304 |
|
305 |
impl<T1, T2, F, M, R> MessageHandling<(T1, T2), M, R> for F
|
306 |
where
|
307 |
F: FnOnce(T1, T2) -> R,
|
308 |
T1: Serialize + DeserializeOwned,
|
309 |
{
|
310 |
fn message_type() -> &'static str {
|
311 |
type_name::<T1>()
|
312 |
}
|
313 |
}
|
314 |
|
315 |
impl<T1, T2, T3, F, M, R> MessageHandling<(T1, T2, T3), M, R> for F
|
316 |
where
|
317 |
F: FnOnce(T1, T2, T3) -> R,
|
318 |
T1: Serialize + DeserializeOwned,
|
319 |
{
|
320 |
fn message_type() -> &'static str {
|
321 |
type_name::<T1>()
|
322 |
}
|
323 |
}
|
324 |
|
325 |
impl<T1, T2, T3, T4, F, M, R> MessageHandling<(T1, T2, T3, T4), M, R> for F
|
326 |
where
|
327 |
F: FnOnce(T1, T2, T3, T4) -> R,
|
328 |
T1: Serialize + DeserializeOwned,
|
329 |
{
|
330 |
fn message_type() -> &'static str {
|
331 |
type_name::<T1>()
|
332 |
}
|
333 |
}
|
334 |
|