1 |
use std::{any::type_name, collections::HashMap, ops::Deref};
|
2 |
|
3 |
use anyhow::Error;
|
4 |
use futures_util::Future;
|
5 |
use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
|
6 |
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
7 |
use serde_json::Value;
|
8 |
|
9 |
use crate::{authentication::UserTokenMetadata, connection::wrapper::ConnectionState};
|
10 |
|
11 |
use super::{instance::Instance, user::User};
|
12 |
|
13 |
#[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
14 |
pub struct Authenticated<T: Serialize> {
|
15 |
#[serde(flatten)]
|
16 |
source: Vec<AuthenticationSource>,
|
17 |
message_type: String,
|
18 |
#[serde(flatten)]
|
19 |
message: T,
|
20 |
}
|
21 |
|
22 |
pub trait AuthenticationSourceProvider: Sized {
|
23 |
fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource;
|
24 |
}
|
25 |
|
26 |
pub trait AuthenticationSourceProviders: Sized {
|
27 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
|
28 |
}
|
29 |
|
30 |
impl<A> AuthenticationSourceProviders for A
|
31 |
where
|
32 |
A: AuthenticationSourceProvider,
|
33 |
{
|
34 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
|
35 |
vec![self.authenticate(payload)]
|
36 |
}
|
37 |
}
|
38 |
|
39 |
impl<A, B> AuthenticationSourceProviders for (A, B)
|
40 |
where
|
41 |
A: AuthenticationSourceProvider,
|
42 |
B: AuthenticationSourceProvider,
|
43 |
{
|
44 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
|
45 |
let (first, second) = self;
|
46 |
|
47 |
vec![first.authenticate(payload), second.authenticate(payload)]
|
48 |
}
|
49 |
}
|
50 |
|
51 |
impl<T: Serialize> Authenticated<T> {
|
52 |
pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self {
|
53 |
let message_payload = serde_json::to_vec(&message).unwrap();
|
54 |
|
55 |
let authentication = auth_sources.authenticate_all(&message_payload);
|
56 |
|
57 |
Self {
|
58 |
source: authentication,
|
59 |
message_type: type_name::<T>().to_string(),
|
60 |
message,
|
61 |
}
|
62 |
}
|
63 |
|
64 |
pub fn new_empty(message: T) -> Self {
|
65 |
Self {
|
66 |
source: vec![],
|
67 |
message_type: type_name::<T>().to_string(),
|
68 |
message,
|
69 |
}
|
70 |
}
|
71 |
|
72 |
pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) {
|
73 |
let message_payload = serde_json::to_vec(&self.message).unwrap();
|
74 |
|
75 |
self.source
|
76 |
.push(authentication.authenticate(&message_payload));
|
77 |
}
|
78 |
}
|
79 |
|
80 |
// Empty placeholder module; nothing is defined in it yet.
mod verified {}
|
81 |
|
82 |
#[derive(Clone, Debug)]
|
83 |
pub struct UserAuthenticator {
|
84 |
pub user: User,
|
85 |
pub token: UserAuthenticationToken,
|
86 |
}
|
87 |
|
88 |
impl AuthenticationSourceProvider for UserAuthenticator {
|
89 |
fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
|
90 |
AuthenticationSource::User {
|
91 |
user: self.user,
|
92 |
token: self.token,
|
93 |
}
|
94 |
}
|
95 |
}
|
96 |
|
97 |
#[derive(Clone)]
|
98 |
pub struct InstanceAuthenticator<'a> {
|
99 |
pub instance: Instance,
|
100 |
pub private_key: &'a str,
|
101 |
}
|
102 |
|
103 |
impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
|
104 |
fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
|
105 |
todo!()
|
106 |
}
|
107 |
}
|
108 |
|
109 |
#[repr(transparent)]
|
110 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
111 |
pub struct UserAuthenticationToken(String);
|
112 |
|
113 |
impl From<String> for UserAuthenticationToken {
|
114 |
fn from(value: String) -> Self {
|
115 |
Self(value)
|
116 |
}
|
117 |
}
|
118 |
|
119 |
impl ToString for UserAuthenticationToken {
|
120 |
fn to_string(&self) -> String {
|
121 |
self.0.clone()
|
122 |
}
|
123 |
}
|
124 |
|
125 |
impl AsRef<str> for UserAuthenticationToken {
|
126 |
fn as_ref(&self) -> &str {
|
127 |
&self.0
|
128 |
}
|
129 |
}
|
130 |
|
131 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
132 |
pub struct InstanceSignature(Vec<u8>);
|
133 |
|
134 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
135 |
pub enum AuthenticationSource {
|
136 |
User {
|
137 |
user: User,
|
138 |
token: UserAuthenticationToken,
|
139 |
},
|
140 |
Instance {
|
141 |
instance: Instance,
|
142 |
signature: InstanceSignature,
|
143 |
},
|
144 |
}
|
145 |
|
146 |
/// Raw bytes of a message; the extractors in this file parse them as JSON.
pub struct NetworkMessage(pub Vec<u8>);
|
147 |
|
148 |
impl Deref for NetworkMessage {
|
149 |
type Target = [u8];
|
150 |
|
151 |
fn deref(&self) -> &Self::Target {
|
152 |
&self.0
|
153 |
}
|
154 |
}
|
155 |
|
156 |
/// A user whose token passed verification; built by this type's `FromMessage` impl.
pub struct AuthenticatedUser(pub User);
|
157 |
|
158 |
#[derive(Debug, thiserror::Error)]
|
159 |
pub enum UserAuthenticationError {
|
160 |
#[error("user authentication missing")]
|
161 |
Missing,
|
162 |
|
163 |
|
164 |
#[error("user token was invalid")]
|
165 |
InvalidToken,
|
166 |
#[error("an error has occured")]
|
167 |
Other(#[from] Error),
|
168 |
}
|
169 |
|
170 |
/// Wrapper marking an `Instance` as authenticated; the field is private and read through `inner`.
pub struct AuthenticatedInstance(Instance);
|
171 |
|
172 |
impl AuthenticatedInstance {
|
173 |
pub fn inner(&self) -> &Instance {
|
174 |
&self.0
|
175 |
}
|
176 |
}
|
177 |
|
178 |
#[async_trait::async_trait]
|
179 |
pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
|
180 |
async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
|
181 |
}
|
182 |
|
183 |
#[async_trait::async_trait]
|
184 |
impl FromMessage<ConnectionState> for AuthenticatedUser {
|
185 |
async fn from_message(
|
186 |
network_message: &NetworkMessage,
|
187 |
state: &ConnectionState,
|
188 |
) -> Result<Self, Error> {
|
189 |
let message: Authenticated<HashMap<String, Value>> =
|
190 |
serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
|
191 |
|
192 |
let (auth_user, auth_token) = message
|
193 |
.source
|
194 |
.iter()
|
195 |
.filter_map(|auth| {
|
196 |
if let AuthenticationSource::User { user, token } = auth {
|
197 |
Some((user, token))
|
198 |
} else {
|
199 |
None
|
200 |
}
|
201 |
})
|
202 |
.next()
|
203 |
.ok_or_else(|| UserAuthenticationError::Missing)?;
|
204 |
|
205 |
let authenticated_instance =
|
206 |
AuthenticatedInstance::from_message(network_message, state).await?;
|
207 |
|
208 |
let public_key_raw = public_key(&auth_user.instance).await?;
|
209 |
let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
|
210 |
|
211 |
let data: TokenData<UserTokenMetadata> = decode(
|
212 |
auth_token.as_ref(),
|
213 |
&verification_key,
|
214 |
&Validation::new(Algorithm::RS256),
|
215 |
)
|
216 |
.unwrap();
|
217 |
|
218 |
if data.claims.user != *auth_user
|
219 |
|| data.claims.generated_for != *authenticated_instance.inner()
|
220 |
{
|
221 |
Err(Error::from(UserAuthenticationError::InvalidToken))
|
222 |
} else {
|
223 |
Ok(AuthenticatedUser(data.claims.user))
|
224 |
}
|
225 |
}
|
226 |
}
|
227 |
|
228 |
#[async_trait::async_trait]
|
229 |
impl FromMessage<ConnectionState> for AuthenticatedInstance {
|
230 |
async fn from_message(
|
231 |
_message: &NetworkMessage,
|
232 |
_state: &ConnectionState,
|
233 |
) -> Result<Self, Error> {
|
234 |
todo!()
|
235 |
}
|
236 |
}
|
237 |
|
238 |
#[async_trait::async_trait]
|
239 |
impl<S, T> FromMessage<S> for Option<T>
|
240 |
where
|
241 |
T: FromMessage<S>,
|
242 |
S: Send + Sync + 'static,
|
243 |
{
|
244 |
async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
|
245 |
Ok(T::from_message(message, state).await.ok())
|
246 |
}
|
247 |
}
|
248 |
|
249 |
#[async_trait::async_trait]
|
250 |
pub trait MessageHandler<T, S, R> {
|
251 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
|
252 |
}
|
253 |
#[async_trait::async_trait]
|
254 |
impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
|
255 |
where
|
256 |
T: FnOnce(T1) -> F + Clone + Send + 'static,
|
257 |
F: Future<Output = Result<R, E>> + Send,
|
258 |
T1: FromMessage<S> + Send,
|
259 |
S: Send + Sync,
|
260 |
E: std::error::Error + Send + Sync + 'static,
|
261 |
{
|
262 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
263 |
let value = T1::from_message(message, state).await?;
|
264 |
self(value).await.map_err(|e| Error::from(e))
|
265 |
}
|
266 |
}
|
267 |
|
268 |
#[async_trait::async_trait]
|
269 |
impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
|
270 |
where
|
271 |
T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
|
272 |
F: Future<Output = Result<R, E>> + Send,
|
273 |
T1: FromMessage<S> + Send,
|
274 |
T2: FromMessage<S> + Send,
|
275 |
S: Send + Sync,
|
276 |
E: std::error::Error + Send + Sync + 'static,
|
277 |
{
|
278 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
279 |
let value = T1::from_message(message, state).await?;
|
280 |
let value_2 = T2::from_message(message, state).await?;
|
281 |
self(value, value_2).await.map_err(|e| Error::from(e))
|
282 |
}
|
283 |
}
|
284 |
|
285 |
#[async_trait::async_trait]
|
286 |
impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
|
287 |
where
|
288 |
T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
|
289 |
F: Future<Output = Result<R, E>> + Send,
|
290 |
T1: FromMessage<S> + Send,
|
291 |
T2: FromMessage<S> + Send,
|
292 |
T3: FromMessage<S> + Send,
|
293 |
S: Send + Sync,
|
294 |
E: std::error::Error + Send + Sync + 'static,
|
295 |
{
|
296 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
297 |
let value = T1::from_message(message, state).await?;
|
298 |
let value_2 = T2::from_message(message, state).await?;
|
299 |
let value_3 = T3::from_message(message, state).await?;
|
300 |
|
301 |
self(value, value_2, value_3)
|
302 |
.await
|
303 |
.map_err(|e| Error::from(e))
|
304 |
}
|
305 |
}
|
306 |
|
307 |
/// Extractor wrapper; its `FromMessage` impl fills it with a clone of the state.
pub struct State<T>(pub T);
|
308 |
|
309 |
#[async_trait::async_trait]
|
310 |
impl<T> FromMessage<T> for State<T>
|
311 |
where
|
312 |
T: Clone + Send + Sync,
|
313 |
{
|
314 |
async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
|
315 |
Ok(Self(state.clone()))
|
316 |
}
|
317 |
}
|
318 |
|
319 |
|
320 |
#[async_trait::async_trait]
|
321 |
impl<T, S> FromMessage<S> for Message<T>
|
322 |
where
|
323 |
T: DeserializeOwned + Send + Sync + Serialize,
|
324 |
S: Clone + Send + Sync,
|
325 |
{
|
326 |
async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
|
327 |
Ok(Message(serde_json::from_slice(&message)?))
|
328 |
}
|
329 |
}
|
330 |
|
331 |
/// Extractor wrapper holding a `T` deserialized from the message's JSON bytes.
pub struct Message<T: Serialize + DeserializeOwned>(pub T);
|
332 |
|
333 |
async fn public_key(instance: &Instance) -> Result<String, Error> {
|
334 |
let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
|
335 |
.await?
|
336 |
.text()
|
337 |
.await?;
|
338 |
|
339 |
Ok(key)
|
340 |
}
|
341 |
|