1 |
use std::{any::type_name, collections::HashMap, ops::Deref};
|
2 |
|
3 |
use anyhow::Error;
|
4 |
use futures_util::Future;
|
5 |
use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
|
6 |
use rsa::{
|
7 |
pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey},
|
8 |
pss::{Signature, SigningKey, VerifyingKey},
|
9 |
sha2::Sha256,
|
10 |
signature::{RandomizedSigner, SignatureEncoding, Verifier},
|
11 |
RsaPrivateKey, RsaPublicKey,
|
12 |
};
|
13 |
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
14 |
use serde_json::Value;
|
15 |
|
16 |
use crate::{authentication::UserTokenMetadata, connection::wrapper::ConnectionState};
|
17 |
|
18 |
use super::{instance::Instance, user::User};
|
19 |
|
20 |
#[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
21 |
pub struct Authenticated<T: Serialize> {
|
22 |
|
23 |
|
24 |
source: Vec<AuthenticationSource>,
|
25 |
message_type: String,
|
26 |
#[serde(flatten)]
|
27 |
message: T,
|
28 |
}
|
29 |
|
30 |
pub trait AuthenticationSourceProvider: Sized {
|
31 |
fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource;
|
32 |
}
|
33 |
|
34 |
pub trait AuthenticationSourceProviders: Sized {
|
35 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
|
36 |
}
|
37 |
|
38 |
impl<A> AuthenticationSourceProviders for A
|
39 |
where
|
40 |
A: AuthenticationSourceProvider,
|
41 |
{
|
42 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
|
43 |
vec![self.authenticate(payload)]
|
44 |
}
|
45 |
}
|
46 |
|
47 |
impl<A, B> AuthenticationSourceProviders for (A, B)
|
48 |
where
|
49 |
A: AuthenticationSourceProvider,
|
50 |
B: AuthenticationSourceProvider,
|
51 |
{
|
52 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
|
53 |
let (first, second) = self;
|
54 |
|
55 |
vec![first.authenticate(payload), second.authenticate(payload)]
|
56 |
}
|
57 |
}
|
58 |
|
59 |
impl<T: Serialize> Authenticated<T> {
|
60 |
pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self {
|
61 |
let message_payload = serde_json::to_vec(&message).unwrap();
|
62 |
|
63 |
let authentication = auth_sources.authenticate_all(&message_payload);
|
64 |
|
65 |
Self {
|
66 |
source: authentication,
|
67 |
message_type: type_name::<T>().to_string(),
|
68 |
message,
|
69 |
}
|
70 |
}
|
71 |
|
72 |
pub fn new_empty(message: T) -> Self {
|
73 |
Self {
|
74 |
source: vec![],
|
75 |
message_type: type_name::<T>().to_string(),
|
76 |
message,
|
77 |
}
|
78 |
}
|
79 |
|
80 |
pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) {
|
81 |
let message_payload = serde_json::to_vec(&self.message).unwrap();
|
82 |
|
83 |
self.source
|
84 |
.push(authentication.authenticate(&message_payload));
|
85 |
}
|
86 |
}
|
87 |
|
88 |
// NOTE(review): empty module; it currently declares no items.
mod verified {}
|
89 |
|
90 |
#[derive(Clone, Debug)]
|
91 |
pub struct UserAuthenticator {
|
92 |
pub user: User,
|
93 |
pub token: UserAuthenticationToken,
|
94 |
}
|
95 |
|
96 |
impl AuthenticationSourceProvider for UserAuthenticator {
|
97 |
fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
|
98 |
AuthenticationSource::User {
|
99 |
user: self.user,
|
100 |
token: self.token,
|
101 |
}
|
102 |
}
|
103 |
}
|
104 |
|
105 |
#[derive(Clone)]
|
106 |
pub struct InstanceAuthenticator<'a> {
|
107 |
pub instance: Instance,
|
108 |
pub private_key: &'a str,
|
109 |
}
|
110 |
|
111 |
impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
|
112 |
fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource {
|
113 |
let mut rng = rand::thread_rng();
|
114 |
|
115 |
let private_key = RsaPrivateKey::from_pkcs1_pem(self.private_key).unwrap();
|
116 |
let signing_key = SigningKey::<Sha256>::new(private_key);
|
117 |
let signature = signing_key.sign_with_rng(&mut rng, &payload);
|
118 |
|
119 |
AuthenticationSource::Instance {
|
120 |
instance: self.instance,
|
121 |
|
122 |
signature: InstanceSignature(signature.to_bytes().into_vec()),
|
123 |
}
|
124 |
}
|
125 |
}
|
126 |
|
127 |
#[repr(transparent)]
|
128 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
129 |
pub struct UserAuthenticationToken(String);
|
130 |
|
131 |
impl From<String> for UserAuthenticationToken {
|
132 |
fn from(value: String) -> Self {
|
133 |
Self(value)
|
134 |
}
|
135 |
}
|
136 |
|
137 |
impl ToString for UserAuthenticationToken {
|
138 |
fn to_string(&self) -> String {
|
139 |
self.0.clone()
|
140 |
}
|
141 |
}
|
142 |
|
143 |
impl AsRef<str> for UserAuthenticationToken {
|
144 |
fn as_ref(&self) -> &str {
|
145 |
&self.0
|
146 |
}
|
147 |
}
|
148 |
|
149 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
150 |
pub struct InstanceSignature(Vec<u8>);
|
151 |
|
152 |
impl AsRef<[u8]> for InstanceSignature {
|
153 |
fn as_ref(&self) -> &[u8] {
|
154 |
&self.0
|
155 |
}
|
156 |
}
|
157 |
|
158 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
159 |
pub enum AuthenticationSource {
|
160 |
User {
|
161 |
user: User,
|
162 |
token: UserAuthenticationToken,
|
163 |
},
|
164 |
Instance {
|
165 |
instance: Instance,
|
166 |
signature: InstanceSignature,
|
167 |
},
|
168 |
}
|
169 |
|
170 |
/// Raw bytes of a message received over the network.
pub struct NetworkMessage(pub Vec<u8>);

/// Lets a `NetworkMessage` be used wherever a byte slice is expected.
impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
|
179 |
|
180 |
/// A user whose token was verified while extracting it from a message.
pub struct AuthenticatedUser(pub User);
|
181 |
|
182 |
#[derive(Debug, thiserror::Error)]
|
183 |
pub enum UserAuthenticationError {
|
184 |
#[error("user authentication missing")]
|
185 |
Missing,
|
186 |
|
187 |
|
188 |
#[error("user token was invalid")]
|
189 |
InvalidToken,
|
190 |
#[error("an error has occured")]
|
191 |
Other(#[from] Error),
|
192 |
}
|
193 |
|
194 |
pub struct AuthenticatedInstance(Instance);
|
195 |
|
196 |
impl AuthenticatedInstance {
|
197 |
pub fn inner(&self) -> &Instance {
|
198 |
&self.0
|
199 |
}
|
200 |
}
|
201 |
|
202 |
#[async_trait::async_trait]
|
203 |
pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
|
204 |
async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
|
205 |
}
|
206 |
|
207 |
#[async_trait::async_trait]
|
208 |
impl FromMessage<ConnectionState> for AuthenticatedUser {
|
209 |
async fn from_message(
|
210 |
network_message: &NetworkMessage,
|
211 |
state: &ConnectionState,
|
212 |
) -> Result<Self, Error> {
|
213 |
let message: Authenticated<HashMap<String, Value>> =
|
214 |
serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
|
215 |
|
216 |
let (auth_user, auth_token) = message
|
217 |
.source
|
218 |
.iter()
|
219 |
.filter_map(|auth| {
|
220 |
if let AuthenticationSource::User { user, token } = auth {
|
221 |
Some((user, token))
|
222 |
} else {
|
223 |
None
|
224 |
}
|
225 |
})
|
226 |
.next()
|
227 |
.ok_or_else(|| UserAuthenticationError::Missing)?;
|
228 |
|
229 |
let authenticated_instance =
|
230 |
AuthenticatedInstance::from_message(network_message, state).await?;
|
231 |
|
232 |
let public_key_raw = public_key(&auth_user.instance).await?;
|
233 |
let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
|
234 |
|
235 |
let data: TokenData<UserTokenMetadata> = decode(
|
236 |
auth_token.as_ref(),
|
237 |
&verification_key,
|
238 |
&Validation::new(Algorithm::RS256),
|
239 |
)
|
240 |
.unwrap();
|
241 |
|
242 |
if data.claims.user != *auth_user
|
243 |
|| data.claims.generated_for != *authenticated_instance.inner()
|
244 |
{
|
245 |
Err(Error::from(UserAuthenticationError::InvalidToken))
|
246 |
} else {
|
247 |
Ok(AuthenticatedUser(data.claims.user))
|
248 |
}
|
249 |
}
|
250 |
}
|
251 |
|
252 |
#[async_trait::async_trait]
|
253 |
impl FromMessage<ConnectionState> for AuthenticatedInstance {
|
254 |
async fn from_message(
|
255 |
network_message: &NetworkMessage,
|
256 |
state: &ConnectionState,
|
257 |
) -> Result<Self, Error> {
|
258 |
let message: Authenticated<Value> =
|
259 |
serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
|
260 |
|
261 |
let (instance, signature) = message
|
262 |
.source
|
263 |
.iter()
|
264 |
.filter_map(|auth: &AuthenticationSource| {
|
265 |
if let AuthenticationSource::Instance {
|
266 |
instance,
|
267 |
signature,
|
268 |
} = auth
|
269 |
{
|
270 |
Some((instance, signature))
|
271 |
} else {
|
272 |
None
|
273 |
}
|
274 |
})
|
275 |
.next()
|
276 |
|
277 |
.ok_or_else(|| UserAuthenticationError::Missing)?;
|
278 |
|
279 |
let public_key = {
|
280 |
let cached_keys = state.cached_keys.read().await;
|
281 |
|
282 |
if let Some(key) = cached_keys.get(&instance) {
|
283 |
key.clone()
|
284 |
} else {
|
285 |
drop(cached_keys);
|
286 |
let mut cached_keys = state.cached_keys.write().await;
|
287 |
let key = public_key(instance).await?;
|
288 |
let public_key = RsaPublicKey::from_pkcs1_pem(&key).unwrap();
|
289 |
cached_keys.insert(instance.clone(), public_key.clone());
|
290 |
public_key
|
291 |
}
|
292 |
};
|
293 |
|
294 |
let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
|
295 |
|
296 |
let message_json = serde_json::to_vec(&message.message).unwrap();
|
297 |
|
298 |
verifying_key.verify(
|
299 |
&message_json,
|
300 |
&Signature::try_from(signature.as_ref()).unwrap(),
|
301 |
)?;
|
302 |
|
303 |
Ok(AuthenticatedInstance(instance.clone()))
|
304 |
}
|
305 |
}
|
306 |
|
307 |
#[async_trait::async_trait]
|
308 |
impl<S, T> FromMessage<S> for Option<T>
|
309 |
where
|
310 |
T: FromMessage<S>,
|
311 |
S: Send + Sync + 'static,
|
312 |
{
|
313 |
async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
|
314 |
Ok(T::from_message(message, state).await.ok())
|
315 |
}
|
316 |
}
|
317 |
|
318 |
#[async_trait::async_trait]
|
319 |
pub trait MessageHandler<T, S, R> {
|
320 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
|
321 |
}
|
322 |
#[async_trait::async_trait]
|
323 |
impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
|
324 |
where
|
325 |
T: FnOnce(T1) -> F + Clone + Send + 'static,
|
326 |
F: Future<Output = Result<R, E>> + Send,
|
327 |
T1: FromMessage<S> + Send,
|
328 |
S: Send + Sync,
|
329 |
E: std::error::Error + Send + Sync + 'static,
|
330 |
{
|
331 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
332 |
let value = T1::from_message(message, state).await?;
|
333 |
self(value).await.map_err(|e| Error::from(e))
|
334 |
}
|
335 |
}
|
336 |
|
337 |
#[async_trait::async_trait]
|
338 |
impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
|
339 |
where
|
340 |
T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
|
341 |
F: Future<Output = Result<R, E>> + Send,
|
342 |
T1: FromMessage<S> + Send,
|
343 |
T2: FromMessage<S> + Send,
|
344 |
S: Send + Sync,
|
345 |
E: std::error::Error + Send + Sync + 'static,
|
346 |
{
|
347 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
348 |
let value = T1::from_message(message, state).await?;
|
349 |
let value_2 = T2::from_message(message, state).await?;
|
350 |
self(value, value_2).await.map_err(|e| Error::from(e))
|
351 |
}
|
352 |
}
|
353 |
|
354 |
#[async_trait::async_trait]
|
355 |
impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
|
356 |
where
|
357 |
T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
|
358 |
F: Future<Output = Result<R, E>> + Send,
|
359 |
T1: FromMessage<S> + Send,
|
360 |
T2: FromMessage<S> + Send,
|
361 |
T3: FromMessage<S> + Send,
|
362 |
S: Send + Sync,
|
363 |
E: std::error::Error + Send + Sync + 'static,
|
364 |
{
|
365 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
366 |
let value = T1::from_message(message, state).await?;
|
367 |
let value_2 = T2::from_message(message, state).await?;
|
368 |
let value_3 = T3::from_message(message, state).await?;
|
369 |
|
370 |
self(value, value_2, value_3)
|
371 |
.await
|
372 |
.map_err(|e| Error::from(e))
|
373 |
}
|
374 |
}
|
375 |
|
376 |
pub struct State<T>(pub T);
|
377 |
|
378 |
#[async_trait::async_trait]
|
379 |
impl<T> FromMessage<T> for State<T>
|
380 |
where
|
381 |
T: Clone + Send + Sync,
|
382 |
{
|
383 |
async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
|
384 |
Ok(Self(state.clone()))
|
385 |
}
|
386 |
}
|
387 |
|
388 |
|
389 |
#[async_trait::async_trait]
|
390 |
impl<T, S> FromMessage<S> for Message<T>
|
391 |
where
|
392 |
T: DeserializeOwned + Send + Sync + Serialize,
|
393 |
S: Clone + Send + Sync,
|
394 |
{
|
395 |
async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
|
396 |
Ok(Message(serde_json::from_slice(&message)?))
|
397 |
}
|
398 |
}
|
399 |
|
400 |
pub struct Message<T: Serialize + DeserializeOwned>(pub T);
|
401 |
|
402 |
async fn public_key(instance: &Instance) -> Result<String, Error> {
|
403 |
let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
|
404 |
.await?
|
405 |
.text()
|
406 |
.await?;
|
407 |
|
408 |
Ok(key)
|
409 |
}
|
410 |
|