JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Fix authenticated endpoints

Amber - ⁨2⁩ years ago

parent: tbd commit: ⁨1400b06

Showing ⁨⁨8⁩ changed files⁩ with ⁨⁨124⁩ insertions⁩ and ⁨⁨68⁩ deletions⁩

Cargo.lock

View file
@@ -163,6 +163,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
163 163 checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
164 164
165 165 [[package]]
166 name = "bincode"
167 version = "1.3.3"
168 source = "registry+https://github.com/rust-lang/crates.io-index"
169 checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
170 dependencies = [
171 "serde",
172 ]
173
174 [[package]]
166 175 name = "bitflags"
167 176 version = "1.3.2"
168 177 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -678,6 +687,7 @@ dependencies = [
678 687 "argon2",
679 688 "async-trait",
680 689 "base64 0.21.3",
690 "bincode",
681 691 "chrono",
682 692 "deadpool",
683 693 "futures-util",
@@ -711,6 +721,7 @@ dependencies = [
711 721 "argon2",
712 722 "async-trait",
713 723 "base64 0.21.3",
724 "bincode",
714 725 "chrono",
715 726 "futures-util",
716 727 "git2",

giterated-daemon/Cargo.toml

View file
@@ -26,6 +26,7 @@ tower = "*"
26 26 giterated-models = { path = "../giterated-models" }
27 27 giterated-api = { path = "../../giterated-api" }
28 28 deadpool = "*"
29 bincode = "*"
29 30
30 31 toml = { version = "0.7" }
31 32

giterated-daemon/src/connection/forwarded.rs

View file
@@ -1,12 +1,12 @@
1 1 use futures_util::{SinkExt, StreamExt};
2 2 use giterated_api::DaemonConnectionPool;
3 use giterated_models::{messages::error::ConnectionError, model::authenticated::Authenticated};
3 use giterated_models::{messages::error::ConnectionError, model::authenticated::{Authenticated, AuthenticatedPayload}};
4 4 use serde::Serialize;
5 5 use tokio_tungstenite::tungstenite::Message;
6 6
7 pub async fn wrap_forwarded<T: Serialize>(
7 pub async fn wrap_forwarded(
8 8 pool: &DaemonConnectionPool,
9 message: Authenticated<T>,
9 message: AuthenticatedPayload,
10 10 ) -> Message {
11 11 let connection = pool.get().await;
12 12

giterated-daemon/src/connection/handshake.rs

View file
@@ -8,7 +8,7 @@ use semver::Version;
8 8
9 9 use crate::{
10 10 connection::ConnectionError,
11 message::{Message, MessageHandler, NetworkMessage, State},
11 message::{Message, MessageHandler, NetworkMessage, State, HandshakeMessage},
12 12 validate_version, version,
13 13 };
14 14
@@ -42,9 +42,10 @@ pub async fn handshake_handle(
42 42 }
43 43
44 44 async fn initiate_handshake(
45 Message(initiation): Message<InitiateHandshake>,
45 HandshakeMessage(initiation): HandshakeMessage<InitiateHandshake>,
46 46 State(connection_state): State<ConnectionState>,
47 47 ) -> Result<(), HandshakeError> {
48 info!("meow!");
48 49 connection_state
49 50 .send(HandshakeResponse {
50 51 identity: connection_state.instance.clone(),
@@ -81,7 +82,7 @@ async fn initiate_handshake(
81 82 }
82 83
83 84 async fn handshake_response(
84 Message(response): Message<HandshakeResponse>,
85 HandshakeMessage(initiation): HandshakeMessage<HandshakeResponse>,
85 86 State(connection_state): State<ConnectionState>,
86 87 ) -> Result<(), HandshakeError> {
87 88 connection_state
@@ -114,7 +115,7 @@ async fn handshake_response(
114 115 }
115 116
116 117 async fn handshake_finalize(
117 Message(finalize): Message<HandshakeFinalize>,
118 HandshakeMessage(finalize): HandshakeMessage<HandshakeFinalize>,
118 119 State(connection_state): State<ConnectionState>,
119 120 ) -> Result<(), HandshakeError> {
120 121 connection_state.handshaked.store(true, Ordering::SeqCst);

giterated-daemon/src/connection/wrapper.rs

View file
@@ -11,7 +11,7 @@ use anyhow::Error;
11 11 use futures_util::{SinkExt, StreamExt};
12 12 use giterated_models::{
13 13 messages::error::ConnectionError,
14 model::{authenticated::Authenticated, instance::Instance},
14 model::{authenticated::{Authenticated, AuthenticatedPayload}, instance::Instance},
15 15 };
16 16 use rsa::RsaPublicKey;
17 17 use serde::Serialize;
@@ -86,13 +86,14 @@ pub async fn connection_wrapper(
86 86 let message = NetworkMessage(payload.clone());
87 87
88 88 if !handshaked {
89 info!("im foo baring");
89 90 if handshake_handle(&message, &connection_state).await.is_ok() {
90 91 if connection_state.handshaked.load(Ordering::SeqCst) {
91 92 handshaked = true;
92 93 }
93 94 }
94 95 } else {
95 let raw = serde_json::from_slice::<Authenticated<Value>>(&payload).unwrap();
96 let raw = serde_json::from_slice::<AuthenticatedPayload>(&payload).unwrap();
96 97
97 98 if let Some(target_instance) = &raw.target_instance {
98 99 // Forward request
@@ -116,7 +117,9 @@ pub async fn connection_wrapper(
116 117
117 118 match authentication_handle(message_type, &message, &connection_state).await {
118 119 Err(e) => {
119 let _ = connection_state.send(ConnectionError(e.to_string())).await;
120 let _ = connection_state
121 .send_raw(ConnectionError(e.to_string()))
122 .await;
120 123 }
121 124 Ok(true) => continue,
122 125 Ok(false) => {}
@@ -124,7 +127,9 @@ pub async fn connection_wrapper(
124 127
125 128 match repository_handle(message_type, &message, &connection_state).await {
126 129 Err(e) => {
127 let _ = connection_state.send(ConnectionError(e.to_string())).await;
130 let _ = connection_state
131 .send_raw(ConnectionError(e.to_string()))
132 .await;
128 133 }
129 134 Ok(true) => continue,
130 135 Ok(false) => {}
@@ -132,7 +137,9 @@ pub async fn connection_wrapper(
132 137
133 138 match user_handle(message_type, &message, &connection_state).await {
134 139 Err(e) => {
135 let _ = connection_state.send(ConnectionError(e.to_string())).await;
140 let _ = connection_state
141 .send_raw(ConnectionError(e.to_string()))
142 .await;
136 143 }
137 144 Ok(true) => continue,
138 145 Ok(false) => {}
@@ -140,7 +147,9 @@ pub async fn connection_wrapper(
140 147
141 148 match authentication_handle(message_type, &message, &connection_state).await {
142 149 Err(e) => {
143 let _ = connection_state.send(ConnectionError(e.to_string())).await;
150 let _ = connection_state
151 .send_raw(ConnectionError(e.to_string()))
152 .await;
144 153 }
145 154 Ok(true) => continue,
146 155 Ok(false) => {}
@@ -189,4 +198,16 @@ impl ConnectionState {
189 198
190 199 Ok(())
191 200 }
201
202 pub async fn send_raw<T: Serialize>(&self, message: T) -> Result<(), Error> {
203 let payload = serde_json::to_string(&message)?;
204 info!("Sending payload: {}", &payload);
205 self.socket
206 .lock()
207 .await
208 .send(Message::Binary(payload.into_bytes()))
209 .await?;
210
211 Ok(())
212 }
192 213 }

giterated-daemon/src/message.rs

View file
@@ -3,7 +3,7 @@ use std::{collections::HashMap, ops::Deref};
3 3 use anyhow::Error;
4 4 use futures_util::Future;
5 5 use giterated_models::model::{
6 authenticated::{Authenticated, AuthenticationSource, UserTokenMetadata},
6 authenticated::{Authenticated, AuthenticationSource, UserTokenMetadata, AuthenticatedPayload},
7 7 instance::Instance,
8 8 user::User,
9 9 };
@@ -63,7 +63,7 @@ impl FromMessage<ConnectionState> for AuthenticatedUser {
63 63 network_message: &NetworkMessage,
64 64 state: &ConnectionState,
65 65 ) -> Result<Self, Error> {
66 let message: Authenticated<HashMap<String, Value>> =
66 let message: AuthenticatedPayload =
67 67 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
68 68
69 69 let (auth_user, auth_token) = message
@@ -108,9 +108,9 @@ impl FromMessage<ConnectionState> for AuthenticatedInstance {
108 108 network_message: &NetworkMessage,
109 109 state: &ConnectionState,
110 110 ) -> Result<Self, Error> {
111 let message: Authenticated<HashMap<String, Value>> =
111 let message: AuthenticatedPayload =
112 112 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
113
113
114 114 let (instance, signature) = message
115 115 .source
116 116 .iter()
@@ -137,7 +137,7 @@ impl FromMessage<ConnectionState> for AuthenticatedInstance {
137 137 } else {
138 138 drop(cached_keys);
139 139 let mut cached_keys = state.cached_keys.write().await;
140 let key = public_key(instance).await?;
140 let key = public_key(&instance).await?;
141 141 let public_key = RsaPublicKey::from_pkcs1_pem(&key).unwrap();
142 142 cached_keys.insert(instance.clone(), public_key.clone());
143 143 public_key
@@ -146,15 +146,8 @@ impl FromMessage<ConnectionState> for AuthenticatedInstance {
146 146
147 147 let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
148 148
149 let message_json = serde_json::to_vec(&message.message).unwrap();
150
151 info!(
152 "Verification against: {}",
153 std::str::from_utf8(&message_json).unwrap()
154 );
155
156 149 verifying_key.verify(
157 &message_json,
150 &message.payload,
158 151 &Signature::try_from(signature.as_ref()).unwrap(),
159 152 )?;
160 153
@@ -251,7 +244,8 @@ where
251 244 S: Clone + Send + Sync,
252 245 {
253 246 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
254 Ok(Message(serde_json::from_slice(&message)?))
247 let payload: AuthenticatedPayload = serde_json::from_slice(&message)?;
248 Ok(Message(bincode::deserialize(&payload.payload)?))
255 249 }
256 250 }
257 251
@@ -265,3 +259,20 @@ async fn public_key(instance: &Instance) -> Result<String, Error> {
265 259
266 260 Ok(key)
267 261 }
262
263 /// Handshake-specific message type.
264 ///
265 /// Uses basic serde_json-based deserialization to maintain the highest
266 /// level of compatibility across versions.
267 pub struct HandshakeMessage<T: Serialize + DeserializeOwned>(pub T);
268
269 #[async_trait::async_trait]
270 impl<T, S> FromMessage<S> for HandshakeMessage<T>
271 where
272 T: DeserializeOwned + Send + Sync + Serialize,
273 S: Clone + Send + Sync,
274 {
275 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
276 Ok(HandshakeMessage(serde_json::from_slice(&message.0)?))
277 }
278 }
278 \ No newline at end of file

giterated-models/Cargo.toml

View file
@@ -23,6 +23,7 @@ argon2 = "*"
23 23 aes-gcm = "0.10.2"
24 24 semver = {version = "*", features = ["serde"]}
25 25 tower = "*"
26 bincode = "*"
26 27
27 28 toml = { version = "0.7" }
28 29

giterated-models/src/model/authenticated.rs

View file
@@ -1,4 +1,4 @@
1 use std::any::type_name;
1 use std::{any::type_name, fmt::Debug};
2 2
3 3 use rsa::{
4 4 pkcs1::DecodeRsaPrivateKey,
@@ -19,28 +19,52 @@ pub struct UserTokenMetadata {
19 19 pub exp: u64,
20 20 }
21 21
22 #[derive(Debug)]
23 pub struct Authenticated<'a, T: Serialize> {
24 pub target_instance: Option<Instance>,
25 pub source: Vec<&'a dyn AuthenticationSourceProvider>,
26 pub message_type: String,
27 pub message: T,
28 }
29
22 30 #[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
23 pub struct Authenticated<T: Serialize> {
31 pub struct AuthenticatedPayload {
24 32 pub target_instance: Option<Instance>,
25 33 pub source: Vec<AuthenticationSource>,
26 34 pub message_type: String,
27 #[serde(flatten)]
28 pub message: T,
35 pub payload: Vec<u8>,
36 }
37
38 impl<'a, T: Serialize> From<Authenticated<'a, T>> for AuthenticatedPayload {
39 fn from(mut value: Authenticated<'a, T>) -> Self {
40 let payload = bincode::serialize(&value.message).unwrap();
41
42 AuthenticatedPayload {
43 target_instance: value.target_instance,
44 source: value
45 .source
46 .drain(..)
47 .map(|provider| provider.authenticate(&payload))
48 .collect::<Vec<_>>(),
49 message_type: value.message_type,
50 payload,
51 }
52 }
29 53 }
30 54
31 pub trait AuthenticationSourceProvider: Sized {
32 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource;
55 pub trait AuthenticationSourceProvider: Debug {
56 fn authenticate(&self, payload: &Vec<u8>) -> AuthenticationSource;
33 57 }
34 58
35 pub trait AuthenticationSourceProviders: Sized {
36 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
59 pub trait AuthenticationSourceProviders: Debug {
60 fn authenticate_all(&self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
37 61 }
38 62
39 63 impl<A> AuthenticationSourceProviders for A
40 64 where
41 65 A: AuthenticationSourceProvider,
42 66 {
43 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
67 fn authenticate_all(&self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
44 68 vec![self.authenticate(payload)]
45 69 }
46 70 }
@@ -50,43 +74,26 @@ where
50 74 A: AuthenticationSourceProvider,
51 75 B: AuthenticationSourceProvider,
52 76 {
53 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
77 fn authenticate_all(&self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
54 78 let (first, second) = self;
55 79
56 80 vec![first.authenticate(payload), second.authenticate(payload)]
57 81 }
58 82 }
59 83
60 impl<T: Serialize> Authenticated<T> {
61 pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self {
62 let message_payload = serde_json::to_vec(&message).unwrap();
63
64 let authentication = auth_sources.authenticate_all(&message_payload);
65
84 impl<'a, T: Serialize + Debug> Authenticated<'a, T> {
85 pub fn new(message: T) -> Self {
66 86 Self {
67 source: authentication,
87 source: vec![],
68 88 message_type: type_name::<T>().to_string(),
69 89 message,
70 90 target_instance: None,
71 91 }
72 92 }
73 93
74 pub fn new_for(
75 instance: impl ToOwned<Owned = Instance>,
76 message: T,
77 auth_sources: impl AuthenticationSourceProvider,
78 ) -> Self {
79 let message_payload = serde_json::to_vec(&message).unwrap();
80
81 info!(
82 "Verifying payload: {}",
83 std::str::from_utf8(&message_payload).unwrap()
84 );
85
86 let authentication = auth_sources.authenticate_all(&message_payload);
87
94 pub fn new_for(instance: impl ToOwned<Owned = Instance>, message: T) -> Self {
88 95 Self {
89 source: authentication,
96 source: vec![],
90 97 message_type: type_name::<T>().to_string(),
91 98 message,
92 99 target_instance: Some(instance.to_owned()),
@@ -102,7 +109,7 @@ impl<T: Serialize> Authenticated<T> {
102 109 }
103 110 }
104 111
105 pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) {
112 pub fn append_authentication(&mut self, authentication: &'a dyn AuthenticationSourceProvider) {
106 113 let message_payload = serde_json::to_vec(&self.message).unwrap();
107 114
108 115 info!(
@@ -110,8 +117,11 @@ impl<T: Serialize> Authenticated<T> {
110 117 std::str::from_utf8(&message_payload).unwrap()
111 118 );
112 119
113 self.source
114 .push(authentication.authenticate(&message_payload));
120 self.source.push(authentication);
121 }
122
123 pub fn into_payload(self) -> AuthenticatedPayload {
124 self.into()
115 125 }
116 126 }
117 127
@@ -124,22 +134,22 @@ pub struct UserAuthenticator {
124 134 }
125 135
126 136 impl AuthenticationSourceProvider for UserAuthenticator {
127 fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
137 fn authenticate(&self, _payload: &Vec<u8>) -> AuthenticationSource {
128 138 AuthenticationSource::User {
129 user: self.user,
130 token: self.token,
139 user: self.user.clone(),
140 token: self.token.clone(),
131 141 }
132 142 }
133 143 }
134 144
135 #[derive(Clone)]
145 #[derive(Debug, Clone)]
136 146 pub struct InstanceAuthenticator<'a> {
137 147 pub instance: Instance,
138 148 pub private_key: &'a str,
139 149 }
140 150
141 151 impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
142 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource {
152 fn authenticate(&self, payload: &Vec<u8>) -> AuthenticationSource {
143 153 let mut rng = rand::thread_rng();
144 154
145 155 let private_key = RsaPrivateKey::from_pkcs1_pem(self.private_key).unwrap();
@@ -147,7 +157,7 @@ impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
147 157 let signature = signing_key.sign_with_rng(&mut rng, &payload);
148 158
149 159 AuthenticationSource::Instance {
150 instance: self.instance,
160 instance: self.instance.clone(),
151 161 // TODO: Actually parse signature from private key
152 162 signature: InstanceSignature(signature.to_bytes().into_vec()),
153 163 }