Finish connection refactor!
parent: tbd commit: bd675cd
Showing 5 changed files with 72 insertions and 25 deletions
Cargo.lock
@@ -629,7 +629,7 @@ dependencies = [ | ||
629 | 629 | |
630 | 630 | [[package]] |
631 | 631 | name = "giterated-daemon" |
632 | version = "0.0.5" | |
632 | version = "0.0.6" | |
633 | 633 | dependencies = [ |
634 | 634 | "aes-gcm", |
635 | 635 | "anyhow", |
Cargo.toml
@@ -1,6 +1,6 @@ | ||
1 | 1 | [package] |
2 | 2 | name = "giterated-daemon" |
3 | version = "0.0.5" | |
3 | version = "0.0.6" | |
4 | 4 | edition = "2021" |
5 | 5 | |
6 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html |
src/connection/wrapper.rs
@@ -1,10 +1,13 @@ | ||
1 | 1 | use std::{ |
2 | 2 | net::SocketAddr, |
3 | sync::{atomic::AtomicBool, Arc}, | |
3 | sync::{ | |
4 | atomic::{AtomicBool, Ordering}, | |
5 | Arc, | |
6 | }, | |
4 | 7 | }; |
5 | 8 | |
6 | 9 | use anyhow::Error; |
7 | use futures_util::SinkExt; | |
10 | use futures_util::{SinkExt, StreamExt}; | |
8 | 11 | use serde::Serialize; |
9 | 12 | use tokio::{net::TcpStream, sync::Mutex}; |
10 | 13 | use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; |
@@ -14,10 +17,13 @@ use crate::{ | ||
14 | 17 | backend::{DiscoveryBackend, RepositoryBackend, UserBackend}, |
15 | 18 | connection::ConnectionError, |
16 | 19 | listener::Listeners, |
17 | model::instance::Instance, | |
20 | model::{authenticated::NetworkMessage, instance::Instance}, | |
18 | 21 | }; |
19 | 22 | |
20 | use super::{connection_worker, Connections}; | |
23 | use super::{ | |
24 | authentication::authentication_handle, connection_worker, handshake::handshake_handle, | |
25 | repository::repository_handle, user::user_handle, Connections, | |
26 | }; | |
21 | 27 | |
22 | 28 | pub async fn connection_wrapper( |
23 | 29 | mut socket: WebSocketStream<TcpStream>, |
@@ -28,26 +34,68 @@ pub async fn connection_wrapper( | ||
28 | 34 | auth_granter: Arc<Mutex<AuthenticationTokenGranter>>, |
29 | 35 | discovery_backend: Arc<Mutex<dyn DiscoveryBackend + Send>>, |
30 | 36 | addr: SocketAddr, |
37 | instance: impl ToOwned<Owned = Instance>, | |
31 | 38 | ) { |
39 | let mut connection_state = ConnectionState { | |
40 | socket: Arc::new(Mutex::new(socket)), | |
41 | listeners, | |
42 | connections, | |
43 | repository_backend, | |
44 | user_backend, | |
45 | auth_granter, | |
46 | discovery_backend, | |
47 | addr, | |
48 | instance: instance.to_owned(), | |
49 | handshaked: Arc::new(AtomicBool::new(false)), | |
50 | }; | |
51 | ||
32 | 52 | let mut handshaked = false; |
53 | ||
33 | 54 | loop { |
34 | if let Err(e) = connection_worker( | |
35 | &mut socket, | |
36 | &mut handshaked, | |
37 | &listeners, | |
38 | &connections, | |
39 | &repository_backend, | |
40 | &user_backend, | |
41 | &auth_granter, | |
42 | &discovery_backend, | |
43 | &addr, | |
44 | ) | |
45 | .await | |
46 | { | |
47 | error!("Error handling message: {:?}", e); | |
55 | let mut socket = connection_state.socket.lock().await; | |
56 | let message = socket.next().await; | |
57 | drop(socket); | |
48 | 58 | |
49 | if let ConnectionError::Shutdown = &e { | |
50 | info!("Closing connection {}", addr); | |
59 | match message { | |
60 | Some(Ok(message)) => { | |
61 | let payload = match message { | |
62 | Message::Binary(payload) => payload, | |
63 | Message::Ping(_) => { | |
64 | let mut socket = connection_state.socket.lock().await; | |
65 | socket.send(Message::Pong(vec![])).await; | |
66 | drop(socket); | |
67 | continue; | |
68 | } | |
69 | Message::Close(_) => return, | |
70 | _ => continue, | |
71 | }; | |
72 | ||
73 | let message = NetworkMessage(payload); | |
74 | ||
75 | if !handshaked { | |
76 | if handshake_handle(&message, &connection_state).await.is_ok() { | |
77 | if connection_state.handshaked.load(Ordering::SeqCst) { | |
78 | handshaked = true; | |
79 | } | |
80 | } | |
81 | } else { | |
82 | if authentication_handle(&message, &connection_state) | |
83 | .await | |
84 | .is_ok() | |
85 | { | |
86 | continue; | |
87 | } else if repository_handle(&message, &connection_state).await.is_ok() { | |
88 | continue; | |
89 | } else if user_handle(&message, &connection_state).await.is_ok() { | |
90 | continue; | |
91 | } else { | |
92 | error!("Message completely unhandled"); | |
93 | continue; | |
94 | } | |
95 | } | |
96 | } | |
97 | _ => { | |
98 | error!("Closing connection for {}", addr); | |
51 | 99 | return; |
52 | 100 | } |
53 | 101 | } |
src/lib.rs
@@ -10,8 +10,6 @@ pub mod listener; | ||
10 | 10 | pub mod messages; |
11 | 11 | pub mod model; |
12 | 12 | |
13 | pub(crate) use std::error::Error as StdError; | |
14 | ||
15 | 13 | #[macro_use] |
16 | 14 | extern crate tracing; |
17 | 15 | |
@@ -20,7 +18,7 @@ pub fn version() -> Version { | ||
20 | 18 | } |
21 | 19 | |
22 | 20 | pub fn validate_version(other: &Version) -> bool { |
23 | let version_req = VersionReq::from_str("=0.0.5").unwrap(); | |
21 | let version_req = VersionReq::from_str("=0.0.6").unwrap(); | |
24 | 22 | |
25 | 23 | version_req.matches(other) |
26 | 24 | } |