diff --git a/Cargo.lock b/Cargo.lock index d746042..2415e98 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,6 +54,17 @@ dependencies = [ [[package]] name = "ahash" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + +[[package]] +name = "ahash" version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" @@ -785,6 +796,7 @@ dependencies = [ "thiserror", "tokio", "tokio-tungstenite", + "tokio-util", "toml", "tracing", "tracing-subscriber", @@ -812,6 +824,7 @@ dependencies = [ "thiserror", "toml", "tracing", + "url", ] [[package]] @@ -852,6 +865,9 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +dependencies = [ + "ahash 0.7.6", +] [[package]] name = "hashbrown" @@ -859,7 +875,7 @@ version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" dependencies = [ - "ahash", + "ahash 0.8.3", "allocator-api2", ] @@ -2092,7 +2108,7 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd4cef4251aabbae751a3710927945901ee1d97ee96d757f6880ebb9a79bfd53" dependencies = [ - "ahash", + "ahash 0.8.3", "atoi", "byteorder", "bytes", @@ -2492,13 +2508,15 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.8" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" +checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", + "hashbrown 0.12.3", "pin-project-lite", "tokio", "tracing", @@ -2703,6 +2721,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] diff --git a/giterated-daemon/Cargo.toml b/giterated-daemon/Cargo.toml index 32b13cd..619dab7 100644 --- a/giterated-daemon/Cargo.toml +++ b/giterated-daemon/Cargo.toml @@ -27,6 +27,7 @@ giterated-api = { path = "../../giterated-api" } giterated-stack = { path = "../giterated-stack" } deadpool = "*" bincode = "*" +tokio-util = {version = "0.7.9", features = ["rt"]} toml = { version = "0.7" } diff --git a/giterated-daemon/src/authorization.rs b/giterated-daemon/src/authorization.rs index 2f65f81..1bb1fe5 100644 --- a/giterated-daemon/src/authorization.rs +++ b/giterated-daemon/src/authorization.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use crate::connection::wrapper::ConnectionState; use giterated_models::error::OperationError; @@ -10,13 +12,10 @@ use giterated_models::repository::{ use giterated_models::user::User; -use giterated_models::{ - object::ObjectRequest, - settings::SetSetting, - value::{GetValue, GiteratedObjectValue}, -}; +use giterated_models::value::GetValueTyped; +use giterated_models::{object::ObjectRequest, settings::SetSetting, value::GiteratedObjectValue}; #[async_trait::async_trait] -pub trait AuthorizedOperation { +pub trait AuthorizedOperation { /// Authorizes the operation, returning whether the operation was /// authorized or not. async fn authorize( @@ -54,8 +53,8 @@ impl AuthorizedOperation for SetSetting { } #[async_trait::async_trait] -impl AuthorizedOperation - for GetValue +impl + AuthorizedOperation for GetValueTyped { async fn authorize( &self, diff --git a/giterated-daemon/src/backend/git.rs b/giterated-daemon/src/backend/git.rs index 2f1fdcb..cb33d91 100644 --- a/giterated-daemon/src/backend/git.rs +++ b/giterated-daemon/src/backend/git.rs @@ -52,10 +52,6 @@ impl GitRepository { user: &Option, settings: &Arc>, ) -> bool { - info!( - "Can user {:?} view repository {}/{}?", - user, self.owner_user, self.name - ); if matches!(self.visibility, RepositoryVisibility::Public) { return true; } @@ -73,7 +69,6 @@ impl GitRepository { if matches!(self.visibility, RepositoryVisibility::Private) { // Check if the user can view - info!("private"); let mut settings = settings.lock().await; let access_list = settings @@ -87,8 +82,6 @@ impl GitRepository { ) .await; - info!("Access list returned"); - let access_list: AccessList = match access_list { Ok(list) => serde_json::from_value(list.0).unwrap(), Err(_) => { @@ -96,8 +89,6 @@ impl GitRepository { } }; - info!("Access list valid"); - access_list .0 .iter() @@ -115,7 +106,7 @@ impl GitRepository { ) -> Result { match git2::Repository::open(format!( "{}/{}/{}/{}", - repository_directory, self.owner_user.instance.url, self.owner_user.username, self.name + repository_directory, self.owner_user.instance, self.owner_user.username, self.name )) { Ok(repository) => Ok(repository), Err(err) => { @@ -214,7 +205,7 @@ impl GitBackend { ) -> Result { if let Err(err) = std::fs::remove_dir_all(PathBuf::from(format!( "{}/{}/{}/{}", - self.repository_folder, user.instance.url, user.username, repository_name + self.repository_folder, user.instance, user.username, repository_name ))) { let err = GitBackendError::CouldNotDeleteFromDisk(err); error!( @@ -245,11 +236,6 @@ impl GitBackend { name: &str, requester: &Option, ) -> Result { - info!( - "Checking permissions for user {:?} on {}/{}", - requester, owner, name - ); - let repository = match self .find_by_owner_user_name( // &request.owner.instance.url, @@ -389,15 +375,12 @@ impl RepositoryBackend for GitBackend { // Create bare (server side) repository on disk match git2::Repository::init_bare(PathBuf::from(format!( "{}/{}/{}/{}", - self.repository_folder, - request.owner.instance.url, - request.owner.username, - request.name + self.repository_folder, request.owner.instance, request.owner.username, request.name ))) { Ok(_) => { debug!( "Created new repository with the name {}/{}/{}", - request.owner.instance.url, request.owner.username, request.name + request.owner.instance, request.owner.username, request.name ); let repository = Repository { diff --git a/giterated-daemon/src/connection/wrapper.rs b/giterated-daemon/src/connection/wrapper.rs index e84072a..f0402f3 100644 --- a/giterated-daemon/src/connection/wrapper.rs +++ b/giterated-daemon/src/connection/wrapper.rs @@ -9,16 +9,10 @@ use futures_util::{SinkExt, StreamExt}; use giterated_models::{ authenticated::{AuthenticationSource, UserTokenMetadata}, - error::OperationError, instance::Instance, }; -use giterated_models::object_backend::ObjectBackend; - -use giterated_models::{ - authenticated::AuthenticatedPayload, message::GiteratedMessage, object::AnyObject, - operation::AnyOperation, -}; +use giterated_models::authenticated::AuthenticatedPayload; use giterated_stack::{ AuthenticatedInstance, AuthenticatedUser, GiteratedStack, StackOperationState, }; @@ -129,7 +123,7 @@ pub async fn connection_wrapper( verified_instance }; - let user = { + let _user = { let mut verified_user = None; if let Some(verified_instance) = &instance { for source in &message.source { @@ -163,36 +157,14 @@ pub async fn connection_wrapper( verified_user }; - let message: GiteratedMessage = message.into_message(); - - operation_state.user = user; - operation_state.instance = instance; - let result = runtime - .object_operation( - message.object, - &message.operation, - message.payload, - &operation_state, - ) + .handle_network_message(message, &operation_state) .await; // Asking for exploits here operation_state.user = None; operation_state.instance = None; - // Map result to Vec on both - let result = match result { - Ok(result) => Ok(serde_json::to_vec(&result).unwrap()), - Err(err) => Err(match err { - OperationError::Operation(err) => { - OperationError::Operation(serde_json::to_vec(&err).unwrap()) - } - OperationError::Internal(err) => OperationError::Internal(err), - OperationError::Unhandled => OperationError::Unhandled, - }), - }; - let mut socket = connection_state.socket.lock().await; let _ = socket .send(Message::Binary(bincode::serialize(&result).unwrap())) diff --git a/giterated-daemon/src/database_backend/handler.rs b/giterated-daemon/src/database_backend/handler.rs index 483b5a6..193e6aa 100644 --- a/giterated-daemon/src/database_backend/handler.rs +++ b/giterated-daemon/src/database_backend/handler.rs @@ -1,6 +1,6 @@ use std::{error::Error, sync::Arc}; -use futures_util::{future::BoxFuture, FutureExt}; +use futures_util::{future::LocalBoxFuture, FutureExt}; use giterated_models::{ authenticated::UserAuthenticationToken, error::{GetValueError, InstanceError, OperationError, RepositoryError, UserError}, @@ -17,11 +17,9 @@ use giterated_models::{ }, settings::{AnySetting, GetSetting, GetSettingError}, user::{Bio, DisplayName, User, UserRepositoriesRequest}, - value::{AnyValue, GetValue}, -}; -use giterated_stack::{ - AuthenticatedUser, AuthorizedInstance, AuthorizedUser, GiteratedStack, StackOperationState, + value::{AnyValue, GetValueTyped}, }; +use giterated_stack::{AuthenticatedUser, AuthorizedInstance, GiteratedStack, StackOperationState}; use super::DatabaseBackend; @@ -31,7 +29,7 @@ pub fn user_get_repositories( state: DatabaseBackend, _operation_state: StackOperationState, requester: Option, -) -> BoxFuture<'static, Result, OperationError>> { +) -> LocalBoxFuture<'static, Result, OperationError>> { let object = object.clone(); async move { @@ -57,14 +55,14 @@ pub fn user_get_repositories( Ok(repositories) } - .boxed() + .boxed_local() } pub fn user_get_value( object: &User, - operation: GetValue>, + operation: GetValueTyped>, state: DatabaseBackend, -) -> BoxFuture<'static, Result, OperationError>> { +) -> LocalBoxFuture<'static, Result, OperationError>> { let object = object.clone(); async move { @@ -76,14 +74,14 @@ pub fn user_get_value( Ok(value) } - .boxed() + .boxed_local() } pub fn user_get_setting( object: &User, operation: GetSetting, state: DatabaseBackend, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -95,7 +93,7 @@ pub fn user_get_setting( Ok(value) } - .boxed() + .boxed_local() } pub fn repository_info( @@ -105,7 +103,7 @@ pub fn repository_info( operation_state: StackOperationState, backend: Arc, requester: Option, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -113,7 +111,6 @@ pub fn repository_info( .get_object::(&object.to_string(), &operation_state) .await .unwrap(); - let mut repository_backend = state.repository_backend.lock().await; let tree = repository_backend .repository_file_inspect( @@ -149,7 +146,7 @@ pub fn repository_info( Ok(info) } - .boxed() + .boxed_local() } pub fn repository_file_from_id( @@ -160,7 +157,7 @@ pub fn repository_file_from_id( backend: Arc, requester: Option, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -182,7 +179,7 @@ pub fn repository_file_from_id( Ok(file) } - .boxed() + .boxed_local() } pub fn repository_file_from_path( @@ -192,7 +189,7 @@ pub fn repository_file_from_path( operation_state: StackOperationState, backend: Arc, requester: Option, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -217,7 +214,7 @@ pub fn repository_file_from_path( Ok(file) } - .boxed() + .boxed_local() } pub fn repository_diff( @@ -227,7 +224,7 @@ pub fn repository_diff( operation_state: StackOperationState, backend: Arc, requester: Option, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -245,7 +242,7 @@ pub fn repository_diff( Ok(diff) } - .boxed() + .boxed_local() } pub fn repository_diff_patch( @@ -255,7 +252,7 @@ pub fn repository_diff_patch( operation_state: StackOperationState, backend: Arc, requester: Option, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -273,7 +270,7 @@ pub fn repository_diff_patch( Ok(patch) } - .boxed() + .boxed_local() } pub fn repository_commit_before( @@ -283,7 +280,7 @@ pub fn repository_commit_before( operation_state: StackOperationState, backend: Arc, requester: Option, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -301,14 +298,14 @@ pub fn repository_commit_before( Ok(file) } - .boxed() + .boxed_local() } pub fn repository_get_value( object: &Repository, - operation: GetValue>, + operation: GetValueTyped>, state: DatabaseBackend, -) -> BoxFuture<'static, Result, OperationError>> { +) -> LocalBoxFuture<'static, Result, OperationError>> { let object = object.clone(); async move { @@ -322,14 +319,14 @@ pub fn repository_get_value( Ok(value) } - .boxed() + .boxed_local() } pub fn repository_get_setting( object: &Repository, operation: GetSetting, state: DatabaseBackend, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -341,7 +338,7 @@ pub fn repository_get_setting( Ok(value) } - .boxed() + .boxed_local() } pub fn instance_authentication_request( @@ -350,7 +347,7 @@ pub fn instance_authentication_request( state: DatabaseBackend, // Authorizes the request for SAME-INSTANCE _authorized_instance: AuthorizedInstance, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { let mut backend = state.user_backend.lock().await; @@ -360,7 +357,7 @@ pub fn instance_authentication_request( .await .map_err(|e| OperationError::Internal(e.to_string())) } - .boxed() + .boxed_local() } pub fn instance_registration_request( @@ -369,7 +366,7 @@ pub fn instance_registration_request( state: DatabaseBackend, // Authorizes the request for SAME-INSTANCE _authorized_instance: AuthorizedInstance, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { async move { let mut backend = state.user_backend.lock().await; @@ -378,7 +375,7 @@ pub fn instance_registration_request( .await .map_err(|e| OperationError::Internal(e.to_string())) } - .boxed() + .boxed_local() } pub fn instance_create_repository_request( @@ -388,7 +385,7 @@ pub fn instance_create_repository_request( requester: AuthenticatedUser, // Authorizes the request for SAME-INSTANCE _authorized_instance: AuthorizedInstance, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { async move { let mut backend = state.repository_backend.lock().await; @@ -397,15 +394,15 @@ pub fn instance_create_repository_request( .await .map_err(|e| OperationError::Internal(e.to_string())) } - .boxed() + .boxed_local() } pub fn user_get_value_display_name( object: &User, - operation: GetValue, + operation: GetValueTyped, state: DatabaseBackend, - _requester: AuthorizedUser, -) -> BoxFuture<'static, Result>> { + // _requester: AuthorizedUser, +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -419,14 +416,14 @@ pub fn user_get_value_display_name( Ok(serde_json::from_value(raw_value.into_inner()) .map_err(|e| OperationError::Internal(e.to_string()))?) } - .boxed() + .boxed_local() } pub fn user_get_value_bio( object: &User, - operation: GetValue, + operation: GetValueTyped, state: DatabaseBackend, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -440,14 +437,14 @@ pub fn user_get_value_bio( Ok(serde_json::from_value(raw_value.into_inner()) .map_err(|e| OperationError::Internal(e.to_string()))?) } - .boxed() + .boxed_local() } pub fn repository_get_value_description( object: &Repository, - operation: GetValue, + operation: GetValueTyped, state: DatabaseBackend, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -461,14 +458,56 @@ pub fn repository_get_value_description( Ok(serde_json::from_value(raw_value.into_inner()) .map_err(|e| OperationError::Internal(e.to_string()))?) } - .boxed() + .boxed_local() } pub fn repository_get_value_visibility( object: &Repository, - operation: GetValue, + operation: GetValueTyped, + state: DatabaseBackend, +) -> LocalBoxFuture<'static, Result>> { + let object = object.clone(); + + async move { + let mut backend = state.repository_backend.lock().await; + + let raw_value = backend + .get_value(&object, &operation.value_name) + .await + .map_err(|e| OperationError::Internal(e.to_string()))?; + + Ok(serde_json::from_value(raw_value.into_inner()) + .map_err(|e| OperationError::Internal(e.to_string()))?) + } + .boxed_local() +} + +pub fn repository_get_default_branch( + object: &Repository, + operation: GetValueTyped, + state: DatabaseBackend, +) -> LocalBoxFuture<'static, Result>> { + let object = object.clone(); + + async move { + let mut backend = state.repository_backend.lock().await; + + let raw_value = backend + .get_value(&object, &operation.value_name) + .await + .map_err(|e| OperationError::Internal(e.to_string()))?; + + Ok(serde_json::from_value(raw_value.into_inner()) + .map_err(|e| OperationError::Internal(e.to_string()))?) + } + .boxed_local() +} + +pub fn repository_get_latest_commit( + object: &Repository, + operation: GetValueTyped, state: DatabaseBackend, -) -> BoxFuture<'static, Result>> { +) -> LocalBoxFuture<'static, Result>> { let object = object.clone(); async move { @@ -482,5 +521,5 @@ pub fn repository_get_value_visibility( Ok(serde_json::from_value(raw_value.into_inner()) .map_err(|e| OperationError::Internal(e.to_string()))?) } - .boxed() + .boxed_local() } diff --git a/giterated-daemon/src/database_backend/mod.rs b/giterated-daemon/src/database_backend/mod.rs index 2eb81c8..c436f77 100644 --- a/giterated-daemon/src/database_backend/mod.rs +++ b/giterated-daemon/src/database_backend/mod.rs @@ -21,14 +21,15 @@ use self::handler::{ instance_authentication_request, instance_create_repository_request, instance_registration_request, repository_commit_before, repository_diff, repository_diff_patch, repository_file_from_id, repository_file_from_path, - repository_get_value_description, repository_get_value_visibility, repository_info, - user_get_value_bio, user_get_value_display_name, + repository_get_default_branch, repository_get_latest_commit, repository_get_value_description, + repository_get_value_visibility, repository_info, user_get_repositories, user_get_value_bio, + user_get_value_display_name, }; #[derive(Clone, Debug)] pub struct Foobackend {} -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl ObjectBackend for Foobackend { async fn object_operation + Debug>( &self, @@ -72,13 +73,13 @@ impl DatabaseBackend { } } - pub fn into_backend(&self) -> SubstackBuilder { - let mut builder = SubstackBuilder::::new(); + pub fn into_substack(self) -> SubstackBuilder { + let mut builder = SubstackBuilder::::new(self); builder - .object::() .object::() - .object::(); + .object::() + .object::(); builder .setting::() @@ -88,12 +89,15 @@ impl DatabaseBackend { .setting::(); builder - .value(user_get_value_display_name) .value(user_get_value_bio) + .value(user_get_value_display_name) .value(repository_get_value_description) - .value(repository_get_value_visibility); + .value(repository_get_value_visibility) + .value(repository_get_default_branch) + .value(repository_get_latest_commit); builder + .operation(user_get_repositories) .operation(repository_info) .operation(repository_file_from_id) .operation(repository_file_from_path) diff --git a/giterated-daemon/src/keys.rs b/giterated-daemon/src/keys.rs index b57d130..260e978 100644 --- a/giterated-daemon/src/keys.rs +++ b/giterated-daemon/src/keys.rs @@ -14,7 +14,7 @@ impl PublicKeyCache { if let Some(key) = self.keys.get(instance) { return Ok(key.clone()); } else { - let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url)) + let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance)) .await? .text() .await?; diff --git a/giterated-daemon/src/lib.rs b/giterated-daemon/src/lib.rs index de61ad9..6e84069 100644 --- a/giterated-daemon/src/lib.rs +++ b/giterated-daemon/src/lib.rs @@ -10,7 +10,6 @@ pub mod connection; pub mod database_backend; pub mod federation; pub mod keys; -pub mod message; #[macro_use] extern crate tracing; diff --git a/giterated-daemon/src/main.rs b/giterated-daemon/src/main.rs index d6533ca..29eb3a5 100644 --- a/giterated-daemon/src/main.rs +++ b/giterated-daemon/src/main.rs @@ -20,8 +20,10 @@ use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite}, net::{TcpListener, TcpStream}, sync::Mutex, + task::LocalSet, }; use tokio_tungstenite::{accept_async, WebSocketStream}; +use tokio_util::task::LocalPoolHandle; use toml::Table; #[macro_use] @@ -91,7 +93,7 @@ async fn main() -> Result<(), Error> { let mut runtime = GiteratedStack::default(); - let database_backend = database_backend.into_backend(); + let database_backend = database_backend.into_substack(); runtime.merge_builder(database_backend); let runtime = Arc::new(runtime); @@ -106,6 +108,8 @@ async fn main() -> Result<(), Error> { } }; + let pool = LocalPoolHandle::new(5); + loop { let stream = accept_stream(&mut listener).await; info!("Connected"); @@ -134,22 +138,35 @@ async fn main() -> Result<(), Error> { }; info!("Websocket connection established with {}", address); - - let connection = RawConnection { - task: tokio::spawn(connection_wrapper( + let connections_cloned = connections.clone(); + let repository_backend = repository_backend.clone(); + let user_backend = user_backend.clone(); + let token_granter = token_granter.clone(); + let settings = settings.clone(); + let instance_connections = instance_connections.clone(); + let config = config.clone(); + let runtime = runtime.clone(); + let operation_state = operation_state.clone(); + + pool.spawn_pinned(move || { + connection_wrapper( connection, - connections.clone(), - repository_backend.clone(), - user_backend.clone(), - token_granter.clone(), - settings.clone(), + connections_cloned, + repository_backend, + user_backend, + token_granter, + settings, address, Instance::from_str(config["giterated"]["instance"].as_str().unwrap()).unwrap(), - instance_connections.clone(), - config.clone(), - runtime.clone(), - operation_state.clone(), - )), + instance_connections, + config, + runtime, + operation_state, + ) + }); + + let connection = RawConnection { + task: tokio::spawn(async move { () }), }; connections.lock().await.connections.push(connection); diff --git a/giterated-daemon/src/message.rs b/giterated-daemon/src/message.rs deleted file mode 100644 index 25df5a2..0000000 --- a/giterated-daemon/src/message.rs +++ /dev/null @@ -1,259 +0,0 @@ -use anyhow::Error; -use futures_util::Future; - -use giterated_models::instance::Instance; - -use giterated_models::user::User; - -use giterated_models::authenticated::{ - AuthenticatedPayload, AuthenticationSource, UserTokenMetadata, -}; -use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation}; -use rsa::{ - pkcs1::DecodeRsaPublicKey, - pss::{Signature, VerifyingKey}, - sha2::Sha256, - signature::Verifier, - RsaPublicKey, -}; -use serde::{de::DeserializeOwned, Serialize}; -use std::{fmt::Debug, ops::Deref}; - -use crate::connection::wrapper::ConnectionState; - -pub struct NetworkMessage(pub Vec); - -impl Deref for NetworkMessage { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -pub struct AuthenticatedUser(pub User); - -#[derive(Debug, thiserror::Error)] -pub enum UserAuthenticationError { - #[error("user authentication missing")] - Missing, - // #[error("{0}")] - // InstanceAuthentication(#[from] Error), - #[error("user token was invalid")] - InvalidToken, - #[error("an error has occured")] - Other(#[from] Error), -} - -pub struct AuthenticatedInstance(Instance); - -impl AuthenticatedInstance { - pub fn inner(&self) -> &Instance { - &self.0 - } -} - -#[async_trait::async_trait] -pub trait FromMessage: Sized + Send + Sync { - async fn from_message(message: &NetworkMessage, state: &S) -> Result; -} - -#[async_trait::async_trait] -impl FromMessage for AuthenticatedUser { - async fn from_message( - network_message: &NetworkMessage, - state: &ConnectionState, - ) -> Result { - let message: AuthenticatedPayload = - serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?; - - let (auth_user, auth_token) = message - .source - .iter() - .filter_map(|auth| { - if let AuthenticationSource::User { user, token } = auth { - Some((user, token)) - } else { - None - } - }) - .next() - .ok_or_else(|| UserAuthenticationError::Missing)?; - - let authenticated_instance = - AuthenticatedInstance::from_message(network_message, state).await?; - - let public_key_raw = state.public_key(&auth_user.instance).await?; - let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap(); - - let data: TokenData = decode( - auth_token.as_ref(), - &verification_key, - &Validation::new(Algorithm::RS256), - ) - .unwrap(); - - if data.claims.user != *auth_user - || data.claims.generated_for != *authenticated_instance.inner() - { - Err(Error::from(UserAuthenticationError::InvalidToken)) - } else { - Ok(AuthenticatedUser(data.claims.user)) - } - } -} - -#[async_trait::async_trait] -impl FromMessage for AuthenticatedInstance { - async fn from_message( - network_message: &NetworkMessage, - state: &ConnectionState, - ) -> Result { - let message: AuthenticatedPayload = - serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?; - - let (instance, signature) = message - .source - .iter() - .filter_map(|auth: &AuthenticationSource| { - if let AuthenticationSource::Instance { - instance, - signature, - } = auth - { - Some((instance, signature)) - } else { - None - } - }) - .next() - // TODO: Instance authentication error - .ok_or_else(|| UserAuthenticationError::Missing)?; - - let public_key = RsaPublicKey::from_pkcs1_pem(&state.public_key(instance).await?).unwrap(); - - let verifying_key: VerifyingKey = VerifyingKey::new(public_key); - - verifying_key.verify( - &message.payload, - &Signature::try_from(signature.as_ref()).unwrap(), - )?; - - Ok(AuthenticatedInstance(instance.clone())) - } -} - -#[async_trait::async_trait] -impl FromMessage for Option -where - T: FromMessage, - S: Send + Sync + 'static, -{ - async fn from_message(message: &NetworkMessage, state: &S) -> Result { - Ok(T::from_message(message, state).await.ok()) - } -} - -#[async_trait::async_trait] -pub trait MessageHandler { - async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result; -} -#[async_trait::async_trait] -impl MessageHandler<(T1,), S, R> for T -where - T: FnOnce(T1) -> F + Clone + Send + 'static, - F: Future> + Send, - T1: FromMessage + Send, - S: Send + Sync, - E: std::error::Error + Send + Sync + 'static, -{ - async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result { - let value = T1::from_message(message, state).await?; - self(value).await.map_err(|e| Error::from(e)) - } -} - -#[async_trait::async_trait] -impl MessageHandler<(T1, T2), S, R> for T -where - T: FnOnce(T1, T2) -> F + Clone + Send + 'static, - F: Future> + Send, - T1: FromMessage + Send, - T2: FromMessage + Send, - S: Send + Sync, - E: std::error::Error + Send + Sync + 'static, -{ - async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result { - let value = T1::from_message(message, state).await?; - let value_2 = T2::from_message(message, state).await?; - self(value, value_2).await.map_err(|e| Error::from(e)) - } -} - -#[async_trait::async_trait] -impl MessageHandler<(T1, T2, T3), S, R> for T -where - T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static, - F: Future> + Send, - T1: FromMessage + Send, - T2: FromMessage + Send, - T3: FromMessage + Send, - S: Send + Sync, - E: std::error::Error + Send + Sync + 'static, -{ - async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result { - let value = T1::from_message(message, state).await?; - let value_2 = T2::from_message(message, state).await?; - let value_3 = T3::from_message(message, state).await?; - - self(value, value_2, value_3) - .await - .map_err(|e| Error::from(e)) - } -} - -pub struct State(pub T); - -#[async_trait::async_trait] -impl FromMessage for State -where - T: Clone + Send + Sync, -{ - async fn from_message(_: &NetworkMessage, state: &T) -> Result { - Ok(Self(state.clone())) - } -} - -// Temp -#[async_trait::async_trait] -impl FromMessage for Message -where - T: DeserializeOwned + Send + Sync + Serialize + Debug, - S: Clone + Send + Sync, -{ - async fn from_message(message: &NetworkMessage, _: &S) -> Result { - let payload: AuthenticatedPayload = serde_json::from_slice(&message)?; - let payload = bincode::deserialize(&payload.payload)?; - - Ok(Message(payload)) - } -} - -pub struct Message(pub T); - -/// Handshake-specific message type. -/// -/// Uses basic serde_json-based deserialization to maintain the highest -/// level of compatibility across versions. -pub struct HandshakeMessage(pub T); - -#[async_trait::async_trait] -impl FromMessage for HandshakeMessage -where - T: DeserializeOwned + Send + Sync + Serialize, - S: Clone + Send + Sync, -{ - async fn from_message(message: &NetworkMessage, _: &S) -> Result { - Ok(HandshakeMessage(serde_json::from_slice(&message.0)?)) - } -} diff --git a/giterated-models/Cargo.toml b/giterated-models/Cargo.toml index 1179cea..24bfa48 100644 --- a/giterated-models/Cargo.toml +++ b/giterated-models/Cargo.toml @@ -24,6 +24,7 @@ git2 = "0.17" chrono = { version = "0.4", features = [ "serde" ] } async-trait = "0.1" serde_with = "3.3.0" +url = {version = "2.4.1", features = ["serde"]} # Git backend sqlx = { version = "0.7", default-features = false, features = [ "macros", "chrono" ] } diff --git a/giterated-models/src/authenticated.rs b/giterated-models/src/authenticated.rs index 8989d05..eb4760a 100644 --- a/giterated-models/src/authenticated.rs +++ b/giterated-models/src/authenticated.rs @@ -8,13 +8,12 @@ use rsa::{ RsaPrivateKey, }; use serde::{Deserialize, Serialize}; -use serde_json::Value; use crate::{ instance::Instance, message::GiteratedMessage, object::{AnyObject, GiteratedObject}, - operation::{AnyOperation, AnyOperationV2, GiteratedOperation}, + operation::{AnyOperation, GiteratedOperation}, user::User, }; @@ -41,19 +40,10 @@ pub struct AuthenticatedPayload { impl AuthenticatedPayload { pub fn into_message(self) -> GiteratedMessage { - let payload = serde_json::from_slice::(&self.payload).unwrap(); GiteratedMessage { object: AnyObject(self.object), operation: self.operation, - payload: AnyOperation(payload), - } - } - pub fn into_message_v2(self) -> GiteratedMessage { - let _payload = serde_json::from_slice::(&self.payload).unwrap(); - GiteratedMessage { - object: AnyObject(self.object), - operation: self.operation, - payload: AnyOperationV2(serde_json::to_vec(&self.payload).unwrap()), + payload: AnyOperation(self.payload), } } } diff --git a/giterated-models/src/instance/mod.rs b/giterated-models/src/instance/mod.rs index c907ce1..3aebd12 100644 --- a/giterated-models/src/instance/mod.rs +++ b/giterated-models/src/instance/mod.rs @@ -7,6 +7,7 @@ mod operations; mod values; pub use operations::*; +use url::Url; pub use values::*; use crate::object::GiteratedObject; @@ -34,9 +35,7 @@ pub struct InstanceMeta { /// assert_eq!(Instance::from_str("giterated.dev").unwrap(), instance); /// ``` #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] -pub struct Instance { - pub url: String, -} +pub struct Instance(pub String); impl GiteratedObject for Instance { fn object_name() -> &'static str { @@ -50,7 +49,7 @@ impl GiteratedObject for Instance { impl Display for Instance { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(&self.url) + f.write_str(&self.0.to_string()) } } @@ -58,7 +57,13 @@ impl FromStr for Instance { type Err = InstanceParseError; fn from_str(s: &str) -> Result { - Ok(Self { url: s.to_string() }) + let with_protocol = format!("wss://{}", s); + + if Url::parse(&with_protocol).is_ok() { + Ok(Self(s.to_string())) + } else { + Err(InstanceParseError::InvalidFormat) + } } } diff --git a/giterated-models/src/message.rs b/giterated-models/src/message.rs index cb04778..ffd4b09 100644 --- a/giterated-models/src/message.rs +++ b/giterated-models/src/message.rs @@ -47,7 +47,7 @@ impl GiteratedMessage { &self, ) -> Result, ()> { let object = O::from_object_str(&self.object.0).map_err(|_| ())?; - let payload = serde_json::from_value::(self.payload.0.clone()).map_err(|_| ())?; + let payload = serde_json::from_slice::(&self.payload.0).map_err(|_| ())?; Ok(GiteratedMessage { object, diff --git a/giterated-models/src/object.rs b/giterated-models/src/object.rs index 8c9a7b3..d419efc 100644 --- a/giterated-models/src/object.rs +++ b/giterated-models/src/object.rs @@ -11,7 +11,7 @@ use crate::{ object_backend::ObjectBackend, operation::GiteratedOperation, settings::{AnySetting, GetSetting, GetSettingError, SetSetting, SetSettingError, Setting}, - value::{GetValue, GiteratedObjectValue}, + value::{GetValueTyped, GiteratedObjectValue}, }; mod operations; @@ -22,7 +22,7 @@ pub struct Object< 'b, S: Clone + Send + Sync, O: GiteratedObject, - B: ObjectBackend + 'b + Send + Sync + Clone, + B: ObjectBackend + 'b + Send + Clone, > { pub(crate) inner: O, pub(crate) backend: B, @@ -56,7 +56,7 @@ impl< } } -pub trait GiteratedObject: Send + Display + FromStr + Sync { +pub trait GiteratedObject: Send + Display + FromStr + Sync + Clone { fn object_name() -> &'static str; fn from_object_str(object_str: &str) -> Result; @@ -69,21 +69,21 @@ impl< B: ObjectBackend, > Object<'b, I, O, B> { - pub async fn get + Send + Debug>( + pub async fn get + Send + Debug + 'static>( &mut self, operation_state: &I, ) -> Result> { let result = self .request( - GetValue { + GetValueTyped:: { value_name: V::value_name().to_string(), - _marker: PhantomData, + ty: Default::default(), }, operation_state, ) .await; - result + Ok(result?) } pub async fn get_setting( @@ -115,7 +115,7 @@ impl< .await } - pub async fn request + Debug>( + pub async fn request + Debug + 'static>( &mut self, request: R, operation_state: &I, diff --git a/giterated-models/src/object_backend.rs b/giterated-models/src/object_backend.rs index 31df970..d231a2e 100644 --- a/giterated-models/src/object_backend.rs +++ b/giterated-models/src/object_backend.rs @@ -6,8 +6,8 @@ use crate::{ use std::fmt::Debug; -#[async_trait::async_trait] -pub trait ObjectBackend: Send + Sync + Sized + Clone { +#[async_trait::async_trait(?Send)] +pub trait ObjectBackend: Sized + Clone + Send { async fn object_operation( &self, object: O, @@ -17,9 +17,9 @@ pub trait ObjectBackend: Send + Sync + Sized + Clone { ) -> Result> where O: GiteratedObject + Debug + 'static, - D: GiteratedOperation + Debug; + D: GiteratedOperation + Debug + 'static; - async fn get_object( + async fn get_object( &self, object_str: &str, operation_state: &S, diff --git a/giterated-models/src/operation.rs b/giterated-models/src/operation.rs index 3cbb488..ab1f443 100644 --- a/giterated-models/src/operation.rs +++ b/giterated-models/src/operation.rs @@ -1,7 +1,6 @@ use std::{any::type_name, fmt::Debug}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::Value; use crate::object::GiteratedObject; @@ -19,20 +18,9 @@ pub trait GiteratedOperation: #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(transparent)] #[repr(transparent)] -pub struct AnyOperation(pub Value); +pub struct AnyOperation(pub Vec); impl GiteratedOperation for AnyOperation { - type Success = Value; - - type Failure = Value; -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(transparent)] -#[repr(transparent)] -pub struct AnyOperationV2(pub Vec); - -impl GiteratedOperation for AnyOperationV2 { type Success = Vec; type Failure = Vec; diff --git a/giterated-models/src/user/mod.rs b/giterated-models/src/user/mod.rs index 23e7680..1cf1a06 100644 --- a/giterated-models/src/user/mod.rs +++ b/giterated-models/src/user/mod.rs @@ -53,7 +53,7 @@ impl GiteratedObject for User { impl Display for User { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}:{}", self.username, self.instance.url) + write!(f, "{}:{}", self.username, self.instance.0) } } diff --git a/giterated-models/src/value.rs b/giterated-models/src/value.rs index 1efe09d..f501620 100644 --- a/giterated-models/src/value.rs +++ b/giterated-models/src/value.rs @@ -12,18 +12,26 @@ pub trait GiteratedObjectValue: Send + Sync + Serialize + DeserializeOwned { } #[derive(Serialize, Deserialize, Debug, Clone)] -pub struct GetValue { +pub struct GetValue { pub value_name: String, - pub(crate) _marker: PhantomData, +} + +impl GiteratedOperation for GetValue { + fn operation_name() -> &'static str { + "get_value" + } + type Success = Value; + type Failure = GetValueError; } #[derive(Serialize, Deserialize, Debug, Clone)] -pub struct GetValueV2 { +pub struct GetValueTyped { pub value_name: String, + pub ty: PhantomData, } -impl + Send> GiteratedOperation - for GetValue +impl> GiteratedOperation + for GetValueTyped { fn operation_name() -> &'static str { "get_value" diff --git a/giterated-stack/src/handler.rs b/giterated-stack/src/handler.rs index 901d352..30d84cb 100644 --- a/giterated-stack/src/handler.rs +++ b/giterated-stack/src/handler.rs @@ -1,14 +1,18 @@ use std::{any::Any, collections::HashMap, sync::Arc}; +use futures_util::FutureExt; use giterated_models::{ authenticated::AuthenticatedPayload, - error::OperationError, + error::{GetValueError, OperationError}, + instance::Instance, message::GiteratedMessage, - object::{AnyObject, GiteratedObject, Object, ObjectRequestError}, + object::{ + AnyObject, GiteratedObject, Object, ObjectRequest, ObjectRequestError, ObjectResponse, + }, object_backend::ObjectBackend, - operation::{AnyOperationV2, GiteratedOperation}, + operation::{AnyOperation, GiteratedOperation}, settings::{GetSetting, SetSetting, Setting}, - value::{GetValue, GetValueV2, GiteratedObjectValue}, + value::{AnyValue, GetValue, GetValueTyped, GiteratedObjectValue}, }; use tracing::trace; @@ -77,13 +81,28 @@ impl HandlerTree { self.elements.push(handler); } - pub fn handle( + pub async fn handle( &self, - _object: &dyn Any, - _operation: Box, - _operation_state: &StackOperationState, - ) -> Result, OperationError>> { - todo!() + object: &Box, + operation: &Box, + operation_state: &StackOperationState, + ) -> Result, OperationError>> { + for handler in self.elements.iter() { + match handler.handle(object, &operation, operation_state).await { + Ok(success) => return Ok(success), + Err(err) => match err { + OperationError::Operation(failure) => { + return Err(OperationError::Operation(failure)) + } + OperationError::Internal(e) => return Err(OperationError::Internal(e)), + _ => { + continue; + } + }, + } + } + + Err(OperationError::Unhandled) } } @@ -113,8 +132,14 @@ pub struct SubstackBuilder { } impl SubstackBuilder { - pub fn new() -> Self { - todo!() + pub fn new(state: S) -> Self { + Self { + operation_handlers: Default::default(), + value_getters: Default::default(), + setting_getters: Default::default(), + metadata: Default::default(), + state, + } } } @@ -127,8 +152,10 @@ impl SubstackBuilder { pub fn operation(&mut self, handler: H) -> &mut Self where O: GiteratedObject + 'static, - D: GiteratedOperation + 'static, - H: GiteratedOperationHandler + 'static + Clone, + D: GiteratedOperation + 'static + Clone, + H: GiteratedOperationHandler + 'static + Clone + Send + Sync, + D::Failure: Send + Sync, + D::Success: Send + Sync, { let object_name = handler.object_name().to_string(); let operation_name = handler.operation_name().to_string(); @@ -140,7 +167,7 @@ impl SubstackBuilder { operation_name, }; - assert!(self.operation_handlers.insert(pair, wrapped).is_none()); + self.operation_handlers.insert(pair, wrapped); self.metadata.register_operation::(); @@ -154,6 +181,26 @@ impl SubstackBuilder { pub fn object(&mut self) -> &mut Self { self.metadata.register_object::(); + // Insert handler so ObjectRequest is handled properly + let handler = move |_object: &Instance, + operation: ObjectRequest, + _state: S, + _operation_state: StackOperationState, + stack: Arc| { + async move { + for (_object_name, object_meta) in stack.metadata.objects.iter() { + if (object_meta.from_str)(&operation.0).is_ok() { + return Ok(ObjectResponse(operation.0)); + } + } + + Err(OperationError::Unhandled) + } + .boxed_local() + }; + + self.operation(handler); + self } @@ -176,14 +223,61 @@ impl SubstackBuilder { pub fn value(&mut self, handler: F) -> &mut Self where O: GiteratedObject + 'static, - V: GiteratedObjectValue + 'static, - F: GiteratedOperationHandler, S> + Clone + 'static, + V: GiteratedObjectValue + 'static + Clone, + F: GiteratedOperationHandler, S> + Clone + 'static + Send + Sync, { let object_name = handler.object_name().to_string(); let value_name = V::value_name().to_string(); let wrapped = OperationWrapper::new(handler, self.state.clone()); + let handler_object_name = object_name.clone(); + let handler_value_name = value_name.clone(); + + // Insert handler so GetValue is handled properly + let _handler = move |object: &O, + operation: GetValueTyped>, + _state: S, + operation_state: StackOperationState, + stack: Arc| { + let stack = stack.clone(); + let object_name = handler_object_name; + let value_name = handler_value_name; + let object = object.clone(); + async move { + for (target, getter) in stack.value_getters.iter() { + if target.object_kind != object_name { + continue; + } + + if target.value_kind != value_name { + continue; + } + + return match getter + .handle( + &(Box::new(object.clone()) as Box), + &(Box::new(GetValueTyped:: { + value_name: operation.value_name, + ty: Default::default(), + }) as Box), + &operation_state, + ) + .await { + Ok(success) => Ok(*success.downcast::< as GiteratedOperation>::Success>().unwrap()), + Err(err) => Err(match err { + OperationError::Operation(failure) => OperationError::Operation(*failure.downcast::< as GiteratedOperation>::Failure>().unwrap()), + OperationError::Internal(internal) => OperationError::Internal(internal), + OperationError::Unhandled => OperationError::Unhandled, + }), + } + } + + Err(OperationError::Unhandled) + } + .boxed_local() + }; + assert!(self .value_getters .insert( @@ -204,7 +298,7 @@ impl SubstackBuilder { pub fn object_settings(&mut self, handler: F) -> &mut Self where O: GiteratedObject + 'static, - F: GiteratedOperationHandler + Clone + 'static, + F: GiteratedOperationHandler + Clone + 'static + Send + Sync, { let object_name = handler.object_name().to_string(); @@ -253,7 +347,7 @@ impl RuntimeMetadata { name: operation_name, object_kind: object_name, deserialize: Box::new(|bytes| { - Ok(Box::new(serde_json::from_slice::(bytes).unwrap()) + Ok(Box::new(serde_json::from_slice::(bytes)?) as Box) }), any_is_same: Box::new(|any_box| any_box.is::()), @@ -291,6 +385,7 @@ impl RuntimeMetadata { ) { let object_name = O::object_name().to_string(); let value_name = V::value_name().to_string(); + let value_name_for_get = V::value_name().to_string(); if self .values @@ -300,8 +395,20 @@ impl RuntimeMetadata { value_kind: value_name.clone(), }, ValueMeta { - name: value_name, + name: value_name.clone(), deserialize: Box::new(|bytes| Ok(Box::new(serde_json::from_slice(&bytes)?))), + serialize: Box::new(|value| { + let value = value.downcast::().unwrap(); + + Ok(serde_json::to_vec(&*value)?) + }), + typed_get: Box::new(move || { + Box::new(GetValueTyped:: { + value_name: value_name_for_get.clone(), + ty: Default::default(), + }) + }), + is_get_value_typed: Box::new(move |typed| typed.is::>()), }, ) .is_some() @@ -350,32 +457,201 @@ impl RuntimeMetadata { self.settings.extend(other.settings); } } +impl GiteratedStack { + /// Handles a giterated network message, returning either a raw success + /// payload or a serialized error payload. + pub async fn handle_network_message( + &self, + message: AuthenticatedPayload, + operation_state: &StackOperationState, + ) -> Result, OperationError>> { + let message: GiteratedMessage = message.into_message(); + + // Deserialize the object, also getting the object type's name + let (object_type, object) = { + let mut result = None; + + for (object_type, object_meta) in self.metadata.objects.iter() { + if let Ok(object) = (object_meta.from_str)(&message.object.0) { + result = Some((object_type.clone(), object)); + break; + } + } + + result + } + .ok_or_else(|| OperationError::Unhandled)?; + + trace!( + "Handling network message {}::<{}>", + message.operation, + object_type + ); + + if message.operation == "get_value" { + // Special case + let operation: GetValue = serde_json::from_slice(&message.payload.0).unwrap(); + + return self + .network_get_value(object, object_type.clone(), operation, operation_state) + .await; + } + + let target = ObjectOperationPair { + object_name: object_type.clone(), + operation_name: message.operation.clone(), + }; + + // Resolve the target operations from the handlers table + let handler = self + .operation_handlers + .get(&target) + .ok_or_else(|| OperationError::Unhandled)?; + + trace!( + "Resolved operation handler for network message {}::<{}>", + message.operation, + object_type + ); + + // Deserialize the operation + let meta = self + .metadata + .operations + .get(&target) + .ok_or_else(|| OperationError::Unhandled)?; + + let operation = (meta.deserialize)(&message.payload.0) + .map_err(|e| OperationError::Internal(e.to_string()))?; + + trace!( + "Deserialized operation for network message {}::<{}>", + message.operation, + object_type + ); + + trace!( + "Calling handler for network message {}::<{}>", + message.operation, + object_type + ); + + // Get the raw result of the operation, where the return values are boxed. + let raw_result = handler.handle(&object, &operation, operation_state).await; + + trace!( + "Finished handling network message {}::<{}>", + message.operation, + object_type + ); + + // Deserialize the raw result for the network + match raw_result { + Ok(success) => Ok((meta.serialize_success)(success) + .map_err(|e| OperationError::Internal(e.to_string()))?), + Err(err) => Err(match err { + OperationError::Operation(failure) => OperationError::Operation( + (meta.serialize_error)(failure) + .map_err(|e| OperationError::Internal(e.to_string()))?, + ), + OperationError::Internal(internal) => OperationError::Internal(internal), + OperationError::Unhandled => OperationError::Unhandled, + }), + } + } + + pub async fn network_get_value( + &self, + object: Box, + object_kind: String, + operation: GetValue, + operation_state: &StackOperationState, + ) -> Result, OperationError>> { + trace!("Handling network get_value for {}", operation.value_name); + + let value_meta = self + .metadata + .values + .get(&ObjectValuePair { + object_kind: object_kind.clone(), + value_kind: operation.value_name.clone(), + }) + .ok_or_else(|| OperationError::Unhandled)?; + + for (target, getter) in self.value_getters.iter() { + if target.object_kind != object_kind { + continue; + } + + if target.value_kind != operation.value_name { + continue; + } -#[async_trait::async_trait] -impl GiteratedOperationHandler for GiteratedStack -where - O: GiteratedObject + 'static, - D: GiteratedOperation + 'static, - S: GiteratedStackState + 'static, -{ - fn operation_name(&self) -> &str { - D::operation_name() + return match getter + .handle(&(object), &((value_meta.typed_get)()), &operation_state) + .await + { + Ok(success) => { + // Serialize success, which is the value type itself + let serialized = (value_meta.serialize)(success) + .map_err(|e| OperationError::Internal(e.to_string()))?; + + Ok(serialized) + } + Err(err) => Err(match err { + OperationError::Operation(failure) => { + // Failure is sourced from GetValue operation, but this is hardcoded for now + let failure: GetValueError = *failure.downcast().unwrap(); + + OperationError::Operation( + serde_json::to_vec(&failure) + .map_err(|e| OperationError::Internal(e.to_string()))?, + ) + } + OperationError::Internal(internal) => OperationError::Internal(internal), + OperationError::Unhandled => OperationError::Unhandled, + }), + }; + } + + Err(OperationError::Unhandled) + } + + pub async fn network_get_setting( + &self, + _operation: GetSetting, + _operation_state: &StackOperationState, + ) -> Result, OperationError>> { + todo!() } - fn object_name(&self) -> &str { - O::object_name() + pub async fn network_set_setting( + &self, + _operation: SetSetting, + _operation_state: &StackOperationState, + ) -> Result, OperationError>> { + todo!() } +} + +use core::fmt::Debug; - async fn handle( +#[async_trait::async_trait(?Send)] +impl ObjectBackend for Arc { + async fn object_operation( &self, - object: &O, - operation: D, - _state: S, + in_object: O, + operation_name: &str, + payload: D, operation_state: &StackOperationState, - ) -> Result> { + ) -> Result> + where + O: GiteratedObject + Debug + 'static, + D: GiteratedOperation + Debug + 'static, + { // Erase object and operation types. - let object = object as &dyn Any; - let operation = Box::new(operation) as Box; + let object = Box::new(in_object.clone()) as Box; + let operation = Box::new(payload) as Box; // We need to determine the type of the object, iterate through all known // object types and check if the &dyn Any we have is the same type as the @@ -384,7 +660,7 @@ where let mut object_type = None; for (object_name, object_meta) in self.metadata.objects.iter() { - if (object_meta.any_is_same)(object) { + if (object_meta.any_is_same)(&in_object) { object_type = Some(object_name.clone()); break; } @@ -395,12 +671,79 @@ where .ok_or_else(|| OperationError::Unhandled)?; // We need to hijack get_value, set_setting, and get_setting. - if operation.is::() { - todo!() + if operation_name == "get_value" { + let mut value_meta = None; + for (_, meta) in self.metadata.values.iter() { + if (meta.is_get_value_typed)(&operation) { + value_meta = Some(meta); + break; + } + } + + let value_meta = value_meta.ok_or_else(|| OperationError::Unhandled)?; + + let value_name = value_meta.name.clone(); + + trace!("Handling get_value for {}::{}", object_type, value_name); + + for (target, getter) in self.value_getters.iter() { + if target.object_kind != object_type { + continue; + } + + if target.value_kind != value_name { + continue; + } + + return match getter + .handle(&(object), &((value_meta.typed_get)()), &operation_state) + .await + { + Ok(success) => Ok(*success.downcast().unwrap()), + Err(err) => Err(match err { + OperationError::Operation(failure) => { + OperationError::Operation(*failure.downcast::().unwrap()) + } + OperationError::Internal(internal) => OperationError::Internal(internal), + OperationError::Unhandled => OperationError::Unhandled, + }), + }; + } + + return Err(OperationError::Unhandled); } else if operation.is::() { - todo!() + let get_setting: Box = operation.downcast().unwrap(); + let setting_name = get_setting.setting_name.clone(); + + // Get the setting getter for the object type + let getter = self + .setting_getters + .get(&object_type) + .ok_or_else(|| OperationError::Unhandled)?; + + let setting = getter + .handle( + &(Box::new(in_object.clone()) as _), + &(Box::new(get_setting) as _), + operation_state, + ) + .await + .map_err(|_e| OperationError::Unhandled)?; + + let _setting_meta = self + .metadata + .settings + .get(&setting_name) + .ok_or_else(|| OperationError::Unhandled)?; + + let setting_success: >::Success = + *setting.downcast().unwrap(); + + return Ok(setting_success); } else if operation.is::() { todo!() + } else if operation.is::() { + todo!() } // Resolve the operation from the known operations table. @@ -413,6 +756,10 @@ where continue; } + if target.operation_name != operation_name { + continue; + } + if (operation_meta.any_is_same)(&operation) { operation_type = Some(target.clone()); break; @@ -429,7 +776,9 @@ where .get(&operation_type) .ok_or_else(|| OperationError::Unhandled)?; - let raw_result = handler_tree.handle(object, operation, operation_state); + let raw_result = handler_tree + .handle(&object, &operation, operation_state) + .await; // Convert the dynamic result back into its concrete type match raw_result { @@ -443,97 +792,21 @@ where }), } } -} -impl GiteratedStack { - /// Handles a giterated network message, returning either a raw success - /// payload or a serialized error payload. - pub async fn handle_network_message( + async fn get_object( &self, - message: AuthenticatedPayload, - _state: &S, - operation_state: &StackOperationState, - ) -> Result, OperationError>> { - let message: GiteratedMessage = message.into_message_v2(); - - // Deserialize the object, also getting the object type's name - let (object_type, object) = { - let mut result = None; - - for (object_type, object_meta) in self.metadata.objects.iter() { - if let Ok(object) = (object_meta.from_str)(&message.object.0) { - result = Some((object_type.clone(), object)); - break; - } + object_str: &str, + _operation_state: &StackOperationState, + ) -> Result, OperationError> { + // TODO: Authorization? + for (_object_name, object_meta) in self.metadata.objects.iter() { + if let Ok(object) = (object_meta.from_str)(object_str) { + return Ok(unsafe { + Object::new_unchecked(*object.downcast::().unwrap(), self.clone()) + }); } - - result - } - .ok_or_else(|| OperationError::Unhandled)?; - - let target = ObjectOperationPair { - object_name: object_type, - operation_name: message.operation, - }; - - // Resolve the target operations from the handlers table - let handler = self - .operation_handlers - .get(&target) - .ok_or_else(|| OperationError::Unhandled)?; - - // Deserialize the operation - let meta = self - .metadata - .operations - .get(&target) - .ok_or_else(|| OperationError::Unhandled)?; - - let operation = - (meta.deserialize)(&message.payload.0).map_err(|_| OperationError::Unhandled)?; - - // Get the raw result of the operation, where the return values are boxed. - let raw_result = handler.handle(&object, operation, operation_state); - - // Deserialize the raw result for the network - match raw_result { - Ok(success) => Ok((meta.serialize_success)(success) - .map_err(|e| OperationError::Internal(e.to_string()))?), - Err(err) => Err(match err { - OperationError::Operation(failure) => OperationError::Operation( - (meta.serialize_error)(failure) - .map_err(|e| OperationError::Internal(e.to_string()))?, - ), - OperationError::Internal(internal) => OperationError::Internal(internal), - OperationError::Unhandled => OperationError::Unhandled, - }), } - } -} - -use core::fmt::Debug; - -#[async_trait::async_trait] -impl ObjectBackend for Arc { - async fn object_operation( - &self, - _object: O, - _operation: &str, - _payload: D, - _operation_state: &StackOperationState, - ) -> Result> - where - O: GiteratedObject + Debug + 'static, - D: GiteratedOperation + Debug, - { - todo!() - } - async fn get_object( - &self, - _object_str: &str, - _operation_state: &StackOperationState, - ) -> Result, OperationError> { - todo!() + Err(OperationError::Unhandled) } } diff --git a/giterated-stack/src/lib.rs b/giterated-stack/src/lib.rs index 1773042..0602016 100644 --- a/giterated-stack/src/lib.rs +++ b/giterated-stack/src/lib.rs @@ -19,40 +19,46 @@ use giterated_models::{ repository::{AccessList, Repository}, settings::{GetSetting, SetSetting}, user::User, - value::{GetValue, GiteratedObjectValue}, + value::GetValue, }; use serde::{de::DeserializeOwned, Serialize}; #[derive(Clone, Debug, Hash, Eq, PartialEq)] struct ObjectOperationPair { - object_name: String, - operation_name: String, + pub object_name: String, + pub operation_name: String, } pub struct SettingMeta { - name: String, - deserialize: Box Result, serde_json::Error> + Send + Sync>, + pub name: String, + pub deserialize: Box Result, serde_json::Error> + Send + Sync>, } pub struct ValueMeta { - name: String, - deserialize: Box Result, serde_json::Error> + Send + Sync>, + pub name: String, + pub deserialize: Box Result, serde_json::Error> + Send + Sync>, + pub serialize: + Box) -> Result, serde_json::Error> + Send + Sync>, + pub typed_get: Box Box + Send + Sync>, + pub is_get_value_typed: Box) -> bool + Send + Sync>, } pub struct ObjectMeta { - name: String, - from_str: Box Result, ()> + Send + Sync>, - any_is_same: Box bool + Send + Sync>, + pub name: String, + pub from_str: Box Result, ()> + Send + Sync>, + pub any_is_same: Box bool + Send + Sync>, } pub struct OperationMeta { - name: String, - object_kind: String, - deserialize: Box Result, ()> + Send + Sync>, - any_is_same: Box bool + Send + Sync>, - serialize_success: + pub name: String, + pub object_kind: String, + pub deserialize: + Box Result, serde_json::Error> + Send + Sync>, + pub any_is_same: Box bool + Send + Sync>, + pub serialize_success: + Box) -> Result, serde_json::Error> + Send + Sync>, + pub serialize_error: Box) -> Result, serde_json::Error> + Send + Sync>, - serialize_error: Box) -> Result, serde_json::Error> + Send + Sync>, } #[derive(Clone, Debug, Hash, Eq, PartialEq)] @@ -61,13 +67,13 @@ pub struct ObjectValuePair { pub value_kind: String, } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] pub trait GiteratedOperationHandler< L, O: GiteratedObject, D: GiteratedOperation, S: Send + Sync + Clone, ->: Send + Sync +> { fn operation_name(&self) -> &str; fn object_name(&self) -> &str; @@ -81,16 +87,16 @@ pub trait GiteratedOperationHandler< ) -> Result>; } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl GiteratedOperationHandler<(), O, D, S> for F where F: FnMut( &O, D, S, - ) -> Pin< - Box>> + Send>, - > + Send + ) + -> Pin>>>> + + Send + Sync + Clone, O: GiteratedObject + Send + Sync, @@ -117,7 +123,7 @@ where } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl GiteratedOperationHandler<(O1,), O, D, S> for F where F: FnMut( @@ -125,9 +131,9 @@ where D, S, O1, - ) -> Pin< - Box>> + Send>, - > + Send + ) + -> Pin>>>> + + Send + Sync + Clone, O: GiteratedObject + Send + Sync, @@ -158,7 +164,7 @@ where } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl GiteratedOperationHandler<(O1, O2), O, D, S> for F where F: FnMut( @@ -167,9 +173,9 @@ where S, O1, O2, - ) -> Pin< - Box>> + Send>, - > + Send + ) + -> Pin>>>> + + Send + Sync + Clone, O: GiteratedObject + Send + Sync, @@ -204,7 +210,7 @@ where } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl GiteratedOperationHandler<(O1, O2, O3), O, D, S> for F where F: FnMut( @@ -214,9 +220,9 @@ where O1, O2, O3, - ) -> Pin< - Box>> + Send>, - > + Send + ) + -> Pin>>>> + + Send + Sync + Clone, O: GiteratedObject + Send + Sync, @@ -258,13 +264,20 @@ where pub struct OperationWrapper { func: Box< dyn Fn( - Box, - Box, + &(dyn Any + Send + Sync), + &(dyn Any + Send + Sync), &(dyn Any + Send + Sync), StackOperationState, - ) - -> Pin, OperationError>>> + Send>> - + Send + ) -> Pin< + Box< + dyn Future< + Output = Result< + Box, + OperationError>, + >, + >, + >, + > + Send + Sync, >, state: Box, @@ -274,31 +287,34 @@ impl OperationWrapper { pub fn new< A, O: GiteratedObject + Send + Sync + 'static, - D: GiteratedOperation + 'static, - F: GiteratedOperationHandler + Send + Sync + 'static + Clone, + D: GiteratedOperation + 'static + Clone, + F: GiteratedOperationHandler + 'static + Send + Sync + Clone, S: GiteratedStackState + 'static, >( handler: F, state: S, - ) -> Self { - let handler = Arc::new(Box::pin(handler)); + ) -> Self + where + D::Failure: Send + Sync, + D::Success: Send + Sync, + { Self { func: Box::new(move |object, operation, state, operation_state| { let handler = handler.clone(); let state = state.downcast_ref::().unwrap().clone(); + let object: &O = object.downcast_ref().unwrap(); + let operation: &D = operation.downcast_ref().unwrap(); + let object = object.clone(); + let operation = operation.clone(); async move { - let handler = handler.clone(); - let object: Box = object.downcast().unwrap(); - let operation: Box = operation.downcast().unwrap(); - let result = handler - .handle(&object, *operation, state, &operation_state) + .handle(&object, operation, state, &operation_state) .await; result - .map(|success| serde_json::to_vec(&success).unwrap()) + .map(|success| Box::new(success) as _) .map_err(|err| match err { OperationError::Operation(err) => { - OperationError::Operation(serde_json::to_vec(&err).unwrap()) + OperationError::Operation(Box::new(err) as _) } OperationError::Internal(internal) => { OperationError::Internal(internal) @@ -306,7 +322,7 @@ impl OperationWrapper { OperationError::Unhandled => OperationError::Unhandled, }) } - .boxed() + .boxed_local() }), state: Box::new(state), } @@ -314,18 +330,22 @@ impl OperationWrapper { async fn handle( &self, - object: Box, - operation: Box, + object: &Box, + operation: &Box, operation_state: &StackOperationState, - ) -> Result, OperationError>> { - (self.func)(object, operation, &self.state, operation_state.clone()).await + ) -> Result, OperationError>> { + (self.func)( + (*object).as_ref(), + (*operation).as_ref(), + self.state.as_ref(), + operation_state.clone(), + ) + .await } } -#[async_trait::async_trait] -pub trait FromOperationState + Send + Sync>: - Sized + Clone + Send -{ +#[async_trait::async_trait(?Send)] +pub trait FromOperationState>: Sized + Clone { type Error: Serialize + DeserializeOwned; async fn from_state( @@ -335,8 +355,8 @@ pub trait FromOperationState + Send ) -> Result>; } -#[async_trait::async_trait] -impl + Send + Sync> FromOperationState +#[async_trait::async_trait(?Send)] +impl> FromOperationState for Arc { type Error = (); @@ -350,8 +370,8 @@ impl + Send + Sync> FromOperationSt } } -#[async_trait::async_trait] -impl + Send + Sync> FromOperationState +#[async_trait::async_trait(?Send)] +impl> FromOperationState for StackOperationState { type Error = (); @@ -365,7 +385,7 @@ impl + Send + Sync> FromOperationSt } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl + Send + Sync> FromOperationState for AuthenticatedUser { @@ -383,7 +403,7 @@ impl + Send + Sync> FromOperationSt } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl + Send + Sync> FromOperationState for AuthenticatedInstance { @@ -401,7 +421,7 @@ impl + Send + Sync> FromOperationSt } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl< T: FromOperationState + Send + Sync, O: GiteratedObject + Sync, @@ -425,7 +445,7 @@ pub struct AuthorizedUser(AuthenticatedUser); #[derive(Clone)] pub struct AuthorizedInstance(AuthenticatedInstance); -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] pub trait AuthorizedOperation: GiteratedOperation { async fn authorize( &self, @@ -434,12 +454,8 @@ pub trait AuthorizedOperation: GiteratedOperation { ) -> Result>; } -#[async_trait::async_trait] -impl< - O: GiteratedObject + Send + Sync + Debug, - V: GiteratedObjectValue + Send + Sync, - > AuthorizedOperation for GetValue -{ +#[async_trait::async_trait(?Send)] +impl AuthorizedOperation for GetValue { async fn authorize( &self, authorize_for: &O, @@ -453,7 +469,7 @@ impl< } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl AuthorizedOperation for SetSetting { async fn authorize( &self, @@ -469,7 +485,7 @@ impl AuthorizedOperation for SetSetting { } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl AuthorizedOperation for GetSetting { async fn authorize( &self, @@ -485,7 +501,7 @@ impl AuthorizedOperation for GetSetting { } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl AuthorizedOperation for SetSetting { async fn authorize( &self, @@ -521,7 +537,7 @@ impl AuthorizedOperation for SetSetting { } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl AuthorizedOperation for GetSetting { async fn authorize( &self, @@ -557,7 +573,7 @@ impl AuthorizedOperation for GetSetting { } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl AuthorizedOperation for RegisterAccountRequest { async fn authorize( &self, @@ -572,7 +588,7 @@ impl AuthorizedOperation for RegisterAccountRequest { } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl AuthorizedOperation for AuthenticationTokenRequest { async fn authorize( &self, @@ -587,7 +603,7 @@ impl AuthorizedOperation for AuthenticationTokenRequest { } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl AuthorizedOperation for RepositoryCreateRequest { async fn authorize( &self, @@ -602,7 +618,7 @@ impl AuthorizedOperation for RepositoryCreateRequest { } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl + Send + Sync> FromOperationState for AuthorizedUser { type Error = (); @@ -624,7 +640,7 @@ impl + Send + Sync> FromOperationState for } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl + Send + Sync> FromOperationState for AuthorizedInstance {