diff --git a/Cargo.lock b/Cargo.lock index 42aa525..d746042 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -824,6 +824,7 @@ dependencies = [ "giterated-models", "serde", "serde_json", + "tokio", "tracing", ] diff --git a/giterated-daemon/src/cache_backend.rs b/giterated-daemon/src/cache_backend.rs index 31bd28b..8b13789 100644 --- a/giterated-daemon/src/cache_backend.rs +++ b/giterated-daemon/src/cache_backend.rs @@ -1,30 +1 @@ -use giterated_models::error::OperationError; -use giterated_models::object::{GiteratedObject, Object, ObjectRequestError}; -use giterated_models::object_backend::ObjectBackend; -use giterated_models::operation::GiteratedOperation; - -use std::fmt::Debug; - -#[derive(Clone, Debug)] -pub struct CacheBackend; - -#[async_trait::async_trait] -impl ObjectBackend for CacheBackend { - async fn object_operation + Debug>( - &self, - _object: O, - _operation: &str, - _payload: D, - ) -> Result> { - // We don't handle operations with this backend - Err(OperationError::Unhandled) - } - - async fn get_object( - &self, - _object_str: &str, - ) -> Result, OperationError> { - Err(OperationError::Unhandled) - } -} diff --git a/giterated-daemon/src/connection/wrapper.rs b/giterated-daemon/src/connection/wrapper.rs index 3b95032..f444374 100644 --- a/giterated-daemon/src/connection/wrapper.rs +++ b/giterated-daemon/src/connection/wrapper.rs @@ -14,6 +14,7 @@ use giterated_models::{ authenticated::AuthenticatedPayload, message::GiteratedMessage, object::AnyObject, operation::AnyOperation, }; +use giterated_stack::{handler::GiteratedBackend, StackOperationState}; use serde::Serialize; use tokio::{net::TcpStream, sync::Mutex}; @@ -41,7 +42,8 @@ pub async fn connection_wrapper( instance: impl ToOwned, instance_connections: Arc>, config: Table, - backend: DatabaseBackend, + backend: GiteratedBackend, + operation_state: StackOperationState, ) { let connection_state = ConnectionState { socket: Arc::new(Mutex::new(socket)), @@ -60,8 +62,6 @@ pub async fn connection_wrapper( let _handshaked = false; - let backend = backend.into_backend(); - loop { let mut socket = connection_state.socket.lock().await; let message = socket.next().await; @@ -86,7 +86,12 @@ pub async fn connection_wrapper( let message: GiteratedMessage = message.into_message(); let result = backend - .object_operation(message.object, &message.operation, message.payload) + .object_operation( + message.object, + &message.operation, + message.payload, + &operation_state, + ) .await; // Map result to Vec on both diff --git a/giterated-daemon/src/database_backend/handler.rs b/giterated-daemon/src/database_backend/handler.rs index b93fa93..0b2e180 100644 --- a/giterated-daemon/src/database_backend/handler.rs +++ b/giterated-daemon/src/database_backend/handler.rs @@ -14,6 +14,7 @@ use giterated_models::{ user::{User, UserRepositoriesRequest}, value::{AnyValue, GetValue}, }; +use giterated_stack::{BackendWrapper, StackOperationState}; use super::DatabaseBackend; @@ -97,13 +98,14 @@ pub fn repository_info( object: &Repository, operation: RepositoryInfoRequest, state: DatabaseBackend, + operation_state: StackOperationState, + backend: BackendWrapper, ) -> BoxFuture<'static, Result>> { let object = object.clone(); - let backend = state.into_backend(); async move { let mut object = backend - .get_object::(&object.to_string()) + .get_object::(&object.to_string(), &operation_state) .await .unwrap(); @@ -125,17 +127,17 @@ pub fn repository_info( let info = RepositoryView { name: object.object().name.clone(), owner: object.object().owner.clone(), - description: object.get::().await.ok(), + description: object.get::(&operation_state).await.ok(), visibility: object - .get::() + .get::(&operation_state) .await .map_err(|e| OperationError::Internal(format!("{:?}: {}", e.source(), e)))?, default_branch: object - .get::() + .get::(&operation_state) .await .map_err(|e| OperationError::Internal(format!("{:?}: {}", e.source(), e)))?, // TODO: Can't be a simple get function, this needs to be returned alongside the tree as this differs depending on the rev and path. - latest_commit: object.get::().await.ok(), + latest_commit: object.get::(&operation_state).await.ok(), tree_rev: operation.rev, tree, }; @@ -149,13 +151,14 @@ pub fn repository_file_from_id( object: &Repository, operation: RepositoryFileFromIdRequest, state: DatabaseBackend, + operation_state: StackOperationState, + backend: BackendWrapper, ) -> BoxFuture<'static, Result>> { let object = object.clone(); - let backend = state.into_backend(); async move { let object = backend - .get_object::(&object.to_string()) + .get_object::(&object.to_string(), &operation_state) .await .unwrap(); @@ -179,13 +182,14 @@ pub fn repository_diff( object: &Repository, operation: RepositoryDiffRequest, state: DatabaseBackend, + operation_state: StackOperationState, + backend: BackendWrapper, ) -> BoxFuture<'static, Result>> { let object = object.clone(); - let backend = state.into_backend(); async move { let object = backend - .get_object::(&object.to_string()) + .get_object::(&object.to_string(), &operation_state) .await .unwrap(); @@ -205,13 +209,14 @@ pub fn repository_commit_before( object: &Repository, operation: RepositoryCommitBeforeRequest, state: DatabaseBackend, + operation_state: StackOperationState, + backend: BackendWrapper, ) -> BoxFuture<'static, Result>> { let object = object.clone(); - let backend = state.into_backend(); async move { let object = backend - .get_object::(&object.to_string()) + .get_object::(&object.to_string(), &operation_state) .await .unwrap(); diff --git a/giterated-daemon/src/database_backend/mod.rs b/giterated-daemon/src/database_backend/mod.rs index 7d75a07..d32f1e3 100644 --- a/giterated-daemon/src/database_backend/mod.rs +++ b/giterated-daemon/src/database_backend/mod.rs @@ -10,7 +10,7 @@ use giterated_models::operation::GiteratedOperation; use giterated_models::repository::Repository; use giterated_models::user::User; use giterated_stack::handler::GiteratedBackend; -use giterated_stack::OperationHandlers; +use giterated_stack::{OperationHandlers, StackOperationState}; use std::fmt::Debug; use tokio::sync::Mutex; @@ -26,12 +26,13 @@ use self::handler::{ pub struct Foobackend {} #[async_trait::async_trait] -impl ObjectBackend for Foobackend { +impl ObjectBackend for Foobackend { async fn object_operation + Debug>( &self, _object: O, _operation: &str, _payload: D, + _operation_state: &StackOperationState, ) -> Result> { // We don't handle operations with this backend Err(OperationError::Unhandled) @@ -40,7 +41,8 @@ impl ObjectBackend for Foobackend { async fn get_object( &self, _object_str: &str, - ) -> Result, OperationError> { + _operation_state: &StackOperationState, + ) -> Result, OperationError> { Err(OperationError::Unhandled) } } @@ -121,6 +123,7 @@ mod test { use giterated_models::user::{DisplayName, User}; use giterated_models::value::{AnyValue, GiteratedObjectValue}; use giterated_stack::handler::GiteratedBackend; + use giterated_stack::StackOperationState; use serde_json::Value; use tokio::sync::Mutex; @@ -274,16 +277,21 @@ mod test { .into_backend() } + fn operation_state() -> StackOperationState { + todo!() + } + #[tokio::test] async fn test_user_get() { let backend = test_backend(); + let operation_state = operation_state(); let mut user = backend - .get_object::("test_user:test.giterated.dev") + .get_object::("test_user:test.giterated.dev", &operation_state) .await .expect("object should have been returned"); - user.get::() + user.get::(&operation_state) .await .expect("object value should have been returned"); } @@ -291,13 +299,14 @@ mod test { #[tokio::test] async fn test_user_get_setting() { let backend = test_backend(); + let operation_state = operation_state(); let mut user = backend - .get_object::("test_user:test.giterated.dev") + .get_object::("test_user:test.giterated.dev", &operation_state) .await .expect("object should have been returned"); - user.get_setting::() + user.get_setting::(&operation_state) .await .expect("object value should have been returned"); } @@ -305,13 +314,14 @@ mod test { #[tokio::test] async fn test_user_set_setting() { let backend = test_backend(); + let operation_state = operation_state(); let mut user = backend - .get_object::("test_user:test.giterated.dev") + .get_object::("test_user:test.giterated.dev", &operation_state) .await .expect("object should have been returned"); - user.set_setting::(DisplayName(String::from("test"))) + user.set_setting::(DisplayName(String::from("test")), &operation_state) .await .expect("object value should have been returned"); } @@ -319,14 +329,18 @@ mod test { #[tokio::test] async fn test_respository_get() { let backend = test_backend(); + let operation_state = operation_state(); let mut repository = backend - .get_object::("test_user:test.giterated.dev/repository@test.giterated.dev") + .get_object::( + "test_user:test.giterated.dev/repository@test.giterated.dev", + &operation_state, + ) .await .expect("object should have been returned"); repository - .get::() + .get::(&operation_state) .await .expect("object value should have been returned"); } @@ -334,14 +348,18 @@ mod test { #[tokio::test] async fn test_repository_get_setting() { let backend = test_backend(); + let operation_state = operation_state(); let mut repository = backend - .get_object::("test_user:test.giterated.dev/repository@test.giterated.dev") + .get_object::( + "test_user:test.giterated.dev/repository@test.giterated.dev", + &operation_state, + ) .await .expect("object should have been returned"); repository - .get_setting::() + .get_setting::(&operation_state) .await .expect("object value should have been returned"); } @@ -349,14 +367,18 @@ mod test { #[tokio::test] async fn test_repository_set_setting() { let backend = test_backend(); + let operation_state = operation_state(); let mut repository = backend - .get_object::("test_user:test.giterated.dev/repository@test.giterated.dev") + .get_object::( + "test_user:test.giterated.dev/repository@test.giterated.dev", + &operation_state, + ) .await .expect("object should have been returned"); repository - .set_setting::(Description(String::from("test"))) + .set_setting::(Description(String::from("test")), &operation_state) .await .expect("object value should have been returned"); } diff --git a/giterated-daemon/src/main.rs b/giterated-daemon/src/main.rs index d2b8b75..4298a0b 100644 --- a/giterated-daemon/src/main.rs +++ b/giterated-daemon/src/main.rs @@ -12,6 +12,7 @@ use giterated_daemon::{ use giterated_models::instance::Instance; +use giterated_stack::{BackendWrapper, StackOperationState}; use sqlx::{postgres::PgConnectOptions, ConnectOptions, PgPool}; use std::{net::SocketAddr, str::FromStr, sync::Arc}; use tokio::{ @@ -88,6 +89,16 @@ async fn main() -> Result<(), Error> { repository_backend.clone(), ); + let backend = database_backend.into_backend(); + + let backend_wrapper = BackendWrapper::new(backend.clone()); + + let operation_state = { + StackOperationState { + giterated_backend: backend_wrapper, + } + }; + loop { let stream = accept_stream(&mut listener).await; info!("Connected"); @@ -129,7 +140,8 @@ async fn main() -> Result<(), Error> { Instance::from_str(config["giterated"]["instance"].as_str().unwrap()).unwrap(), instance_connections.clone(), config.clone(), - database_backend.clone(), + backend.clone(), + operation_state.clone(), )), }; diff --git a/giterated-models/src/instance/operations.rs b/giterated-models/src/instance/operations.rs index 62271f3..c9404a6 100644 --- a/giterated-models/src/instance/operations.rs +++ b/giterated-models/src/instance/operations.rs @@ -106,18 +106,22 @@ impl GiteratedOperation for RepositoryCreateRequest { type Failure = InstanceError; } -impl Object<'_, Instance, B> { +impl + std::fmt::Debug> Object<'_, S, Instance, B> { pub async fn register_account( &mut self, email: Option<&str>, username: &str, password: &Secret, + operation_state: &S, ) -> Result> { - self.request::(RegisterAccountRequest { - username: username.to_string(), - email: email.map(|s| s.to_string()), - password: password.clone(), - }) + self.request::( + RegisterAccountRequest { + username: username.to_string(), + email: email.map(|s| s.to_string()), + password: password.clone(), + }, + operation_state, + ) .await } @@ -125,12 +129,16 @@ impl Object<'_, Instance, B> { &mut self, username: &str, password: &Secret, + operation_state: &S, ) -> Result> { - self.request::(AuthenticationTokenRequest { - instance: self.inner.clone(), - username: username.to_string(), - password: password.clone(), - }) + self.request::( + AuthenticationTokenRequest { + instance: self.inner.clone(), + username: username.to_string(), + password: password.clone(), + }, + operation_state, + ) .await } @@ -139,22 +147,30 @@ impl Object<'_, Instance, B> { instance: &Instance, username: &str, password: &Secret, + operation_state: &S, ) -> Result> { - self.request::(AuthenticationTokenRequest { - instance: instance.clone(), - username: username.to_string(), - password: password.clone(), - }) + self.request::( + AuthenticationTokenRequest { + instance: instance.clone(), + username: username.to_string(), + password: password.clone(), + }, + operation_state, + ) .await } pub async fn token_extension( &mut self, token: &UserAuthenticationToken, + operation_state: &S, ) -> Result, OperationError> { - self.request::(TokenExtensionRequest { - token: token.clone(), - }) + self.request::( + TokenExtensionRequest { + token: token.clone(), + }, + operation_state, + ) .await } @@ -165,15 +181,19 @@ impl Object<'_, Instance, B> { visibility: &RepositoryVisibility, default_branch: &str, owner: &User, + operation_state: &S, ) -> Result> { - self.request::(RepositoryCreateRequest { - instance: Some(instance.clone()), - name: name.to_string(), - description: None, - visibility: visibility.clone(), - default_branch: default_branch.to_string(), - owner: owner.clone(), - }) + self.request::( + RepositoryCreateRequest { + instance: Some(instance.clone()), + name: name.to_string(), + description: None, + visibility: visibility.clone(), + default_branch: default_branch.to_string(), + owner: owner.clone(), + }, + operation_state, + ) .await } } diff --git a/giterated-models/src/object.rs b/giterated-models/src/object.rs index 6cac03d..20d0be4 100644 --- a/giterated-models/src/object.rs +++ b/giterated-models/src/object.rs @@ -18,18 +18,25 @@ mod operations; pub use operations::*; #[derive(Debug, Clone)] -pub struct Object<'b, O: GiteratedObject, B: ObjectBackend + 'b + Send + Sync + Clone> { +pub struct Object< + 'b, + S: Clone + Send + Sync, + O: GiteratedObject, + B: ObjectBackend + 'b + Send + Sync + Clone, +> { pub(crate) inner: O, pub(crate) backend: B, - pub(crate) _marker: PhantomData<&'b ()>, + pub(crate) _marker: PhantomData<&'b S>, } -impl<'b, B: ObjectBackend + Send + Sync + Clone, O: GiteratedObject> Object<'b, O, B> { +impl<'b, S: Clone + Send + Sync, B: ObjectBackend + Send + Sync + Clone, O: GiteratedObject> + Object<'b, S, O, B> +{ pub fn object(&self) -> &O { &self.inner } - pub unsafe fn new_unchecked(object: O, backend: B) -> Object<'b, O, B> { + pub unsafe fn new_unchecked(object: O, backend: B) -> Object<'b, S, O, B> { Object { inner: object, backend, @@ -38,8 +45,11 @@ impl<'b, B: ObjectBackend + Send + Sync + Clone, O: GiteratedObject> Object<'b, } } -impl Display - for Object<'_, O, B> +impl< + S: Clone + Send + Sync, + O: GiteratedObject + Display, + B: ObjectBackend + Send + Sync + Clone, + > Display for Object<'_, S, O, B> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.inner.fmt(f) @@ -52,15 +62,21 @@ pub trait GiteratedObject: Send + Display + FromStr { fn from_object_str(object_str: &str) -> Result; } -impl<'b, O: GiteratedObject + Clone + Debug, B: ObjectBackend> Object<'b, O, B> { +impl<'b, I: Clone + Send + Sync, O: GiteratedObject + Clone + Debug, B: ObjectBackend> + Object<'b, I, O, B> +{ pub async fn get + Send + Debug>( &mut self, + operation_state: &I, ) -> Result> { let result = self - .request(GetValue { - value_name: V::value_name().to_string(), - _marker: PhantomData, - }) + .request( + GetValue { + value_name: V::value_name().to_string(), + _marker: PhantomData, + }, + operation_state, + ) .await; result @@ -68,31 +84,45 @@ impl<'b, O: GiteratedObject + Clone + Debug, B: ObjectBackend> Object<'b, O, B> pub async fn get_setting( &mut self, + operation_state: &I, ) -> Result> { - self.request(GetSetting { - setting_name: S::name().to_string(), - _marker: PhantomData, - }) + self.request( + GetSetting { + setting_name: S::name().to_string(), + _marker: PhantomData, + }, + operation_state, + ) .await } pub async fn set_setting( &mut self, setting: S, + operation_state: &I, ) -> Result<(), OperationError> { - self.request(SetSetting { - setting_name: S::name().to_string(), - value: setting, - }) + self.request( + SetSetting { + setting_name: S::name().to_string(), + value: setting, + }, + operation_state, + ) .await } pub async fn request + Debug>( &mut self, request: R, + operation_state: &I, ) -> Result> { self.backend - .object_operation(self.inner.clone(), R::operation_name(), request) + .object_operation( + self.inner.clone(), + R::operation_name(), + request, + operation_state, + ) .await } } diff --git a/giterated-models/src/object_backend.rs b/giterated-models/src/object_backend.rs index 4ae7897..6af7bfb 100644 --- a/giterated-models/src/object_backend.rs +++ b/giterated-models/src/object_backend.rs @@ -7,12 +7,13 @@ use crate::{ use std::fmt::Debug; #[async_trait::async_trait] -pub trait ObjectBackend: Send + Sync + Sized + Clone { +pub trait ObjectBackend: Send + Sync + Sized + Clone { async fn object_operation( &self, object: O, operation: &str, payload: D, + operation_state: &S, ) -> Result> where O: GiteratedObject + Debug, @@ -21,5 +22,6 @@ pub trait ObjectBackend: Send + Sync + Sized + Clone { async fn get_object( &self, object_str: &str, - ) -> Result, OperationError>; + operation_state: &S, + ) -> Result, OperationError>; } diff --git a/giterated-models/src/operation.rs b/giterated-models/src/operation.rs index b816649..3b3dc23 100644 --- a/giterated-models/src/operation.rs +++ b/giterated-models/src/operation.rs @@ -24,3 +24,8 @@ impl GiteratedOperation for AnyOperation { type Failure = Value; } + +/// The internal state of an operation, used to provide authentication information +/// and the ability to make giterated calls within handlers. +#[derive(Clone)] +pub struct GiteratedOperationState(pub S); diff --git a/giterated-models/src/repository/operations.rs b/giterated-models/src/repository/operations.rs index 112637b..326d2fa 100644 --- a/giterated-models/src/repository/operations.rs +++ b/giterated-models/src/repository/operations.rs @@ -153,44 +153,60 @@ impl GiteratedOperation for RepositoryFileInspectRequest { type Failure = RepositoryError; } -impl Object<'_, Repository, B> { +impl + std::fmt::Debug> Object<'_, S, Repository, B> { pub async fn info( &mut self, extra_metadata: bool, rev: Option, path: Option, + operation_state: &S, ) -> Result> { - self.request::(RepositoryInfoRequest { - extra_metadata, - rev, - path, - }) + self.request::( + RepositoryInfoRequest { + extra_metadata, + rev, + path, + }, + operation_state, + ) .await } pub async fn file_from_id( &mut self, id: String, + operation_state: &S, ) -> Result> { - self.request::(RepositoryFileFromIdRequest(id)) - .await + self.request::( + RepositoryFileFromIdRequest(id), + operation_state, + ) + .await } pub async fn diff( &mut self, old_id: String, new_id: String, + operation_state: &S, ) -> Result> { - self.request::(RepositoryDiffRequest { old_id, new_id }) - .await + self.request::( + RepositoryDiffRequest { old_id, new_id }, + operation_state, + ) + .await } pub async fn commit_before( &mut self, id: String, + operation_state: &S, ) -> Result> { - self.request::(RepositoryCommitBeforeRequest(id)) - .await + self.request::( + RepositoryCommitBeforeRequest(id), + operation_state, + ) + .await } // pub async fn issues_count(&mut self) -> Result> { @@ -200,15 +216,17 @@ impl Object<'_, Repository, B> { pub async fn issue_labels( &mut self, + operation_state: &S, ) -> Result, OperationError> { - self.request::(RepositoryIssueLabelsRequest) + self.request::(RepositoryIssueLabelsRequest, operation_state) .await } pub async fn issues( &mut self, + operation_state: &S, ) -> Result, OperationError> { - self.request::(RepositoryIssuesRequest) + self.request::(RepositoryIssuesRequest, operation_state) .await } @@ -217,12 +235,16 @@ impl Object<'_, Repository, B> { extra_metadata: bool, rev: Option<&str>, path: Option<&str>, + operation_state: &S, ) -> Result, OperationError> { - self.request::(RepositoryFileInspectRequest { - extra_metadata, - rev: rev.map(|r| r.to_string()), - path: path.map(|p| p.to_string()), - }) + self.request::( + RepositoryFileInspectRequest { + extra_metadata, + rev: rev.map(|r| r.to_string()), + path: path.map(|p| p.to_string()), + }, + operation_state, + ) .await } } diff --git a/giterated-models/src/user/operations.rs b/giterated-models/src/user/operations.rs index c81a7ff..870c5ab 100644 --- a/giterated-models/src/user/operations.rs +++ b/giterated-models/src/user/operations.rs @@ -22,15 +22,19 @@ impl GiteratedOperation for UserRepositoriesRequest { type Failure = UserError; } -impl Object<'_, User, B> { +impl + std::fmt::Debug> Object<'_, S, User, B> { pub async fn repositories( &mut self, instance: &Instance, + operation_state: &S, ) -> Result, OperationError> { - self.request::(UserRepositoriesRequest { - instance: instance.clone(), - user: self.inner.clone(), - }) + self.request::( + UserRepositoriesRequest { + instance: instance.clone(), + user: self.inner.clone(), + }, + operation_state, + ) .await } } diff --git a/giterated-stack/Cargo.toml b/giterated-stack/Cargo.toml index 97c43f6..efc9745 100644 --- a/giterated-stack/Cargo.toml +++ b/giterated-stack/Cargo.toml @@ -12,4 +12,5 @@ serde = { version = "1.0.188", features = [ "derive" ]} serde_json = "1.0" bincode = "*" futures-util = "*" -tracing = "*" \ No newline at end of file +tracing = "*" +tokio = { version = "1.32.0", features = [ "full" ] } \ No newline at end of file diff --git a/giterated-stack/src/handler.rs b/giterated-stack/src/handler.rs index a2c48d3..10cbde7 100644 --- a/giterated-stack/src/handler.rs +++ b/giterated-stack/src/handler.rs @@ -11,6 +11,8 @@ use tracing::warn; use crate::{state::HandlerState, OperationHandlers}; +use crate::StackOperationState; + #[derive(Clone)] pub struct GiteratedBackend { state: S, @@ -24,15 +26,20 @@ impl GiteratedBackend { handlers: Arc::new(handlers), } } + + pub fn state(&self) -> &S { + &self.state + } } #[async_trait::async_trait] -impl ObjectBackend for GiteratedBackend { +impl ObjectBackend for GiteratedBackend { async fn object_operation( &self, object: O, operation: &str, payload: D, + operation_state: &StackOperationState, ) -> Result> where O: GiteratedObject + Debug, @@ -50,6 +57,7 @@ impl ObjectBackend for GiteratedBackend { AnyObject(object.clone()), serde_json::from_value(serialized).unwrap(), self.state.clone(), + &operation_state, ) .await; @@ -81,6 +89,7 @@ impl ObjectBackend for GiteratedBackend { operation, AnyOperation(serialized), self.state.clone(), + &operation_state, ) .await; @@ -108,13 +117,15 @@ impl ObjectBackend for GiteratedBackend { async fn get_object( &self, object_str: &str, - ) -> Result, OperationError> { + operation_state: &StackOperationState, + ) -> Result, OperationError> { let raw_result = self .handlers .resolve_object( AnyObject("giterated.dev".to_string()), ObjectRequest(object_str.to_string()), self.state.clone(), + operation_state, ) .await; diff --git a/giterated-stack/src/lib.rs b/giterated-stack/src/lib.rs index 6f252ea..7438caf 100644 --- a/giterated-stack/src/lib.rs +++ b/giterated-stack/src/lib.rs @@ -7,12 +7,20 @@ use futures_util::FutureExt; use giterated_models::{ error::OperationError, instance::Instance, - object::{AnyObject, GiteratedObject, ObjectRequest, ObjectResponse}, + object::{ + AnyObject, GiteratedObject, Object, ObjectRequest, ObjectRequestError, ObjectResponse, + }, + object_backend::ObjectBackend, operation::{AnyOperation, GiteratedOperation}, repository::Repository, user::User, }; -use tracing::info; +use handler::GiteratedBackend; +use serde::{de::DeserializeOwned, Serialize}; +use serde_json::Value; +use state::HandlerState; +use tokio::{sync::mpsc::channel, task::JoinHandle}; +use tracing::{error, warn}; #[derive(Clone, Debug, Hash, Eq, PartialEq)] struct ObjectOperationPair { @@ -36,9 +44,10 @@ impl Default for OperationHandlers { impl OperationHandlers { pub fn insert< + A, O: GiteratedObject + Send + Sync, D: GiteratedOperation + 'static, - H: GiteratedOperationHandler + Send + Sync + 'static + Clone, + H: GiteratedOperationHandler + Send + Sync + 'static + Clone, >( &mut self, handler: H, @@ -83,6 +92,7 @@ impl OperationHandlers { operation_name: &str, operation: AnyOperation, state: S, + operation_state: &StackOperationState, ) -> Result, OperationError>> { // TODO let object = object.to_string(); @@ -111,6 +121,7 @@ impl OperationHandlers { AnyObject(object.to_string()), operation.clone(), state.clone(), + operation_state, ) .await } else { @@ -123,6 +134,7 @@ impl OperationHandlers { instance: AnyObject, request: ObjectRequest, state: S, + operation_state: &StackOperationState, ) -> Result, OperationError>> { for handler in self.get_object.iter() { if let Ok(response) = handler @@ -130,6 +142,7 @@ impl OperationHandlers { instance.clone(), AnyOperation(serde_json::to_value(request.clone()).unwrap()), state.clone(), + operation_state, ) .await { @@ -143,6 +156,7 @@ impl OperationHandlers { #[async_trait::async_trait] pub trait GiteratedOperationHandler< + L, O: GiteratedObject, D: GiteratedOperation, S: Send + Sync + Clone, @@ -156,11 +170,12 @@ pub trait GiteratedOperationHandler< object: &O, operation: D, state: S, + operation_state: &StackOperationState, ) -> Result>; } #[async_trait::async_trait] -impl GiteratedOperationHandler for F +impl GiteratedOperationHandler<(), O, D, S> for F where F: FnMut( &O, @@ -189,17 +204,106 @@ where object: &O, operation: D, state: S, + _operation_state: &StackOperationState, ) -> Result> { self.clone()(object, operation, state).await } } +#[async_trait::async_trait] +impl GiteratedOperationHandler<(O1,), O, D, S> for F +where + F: FnMut( + &O, + D, + S, + O1, + ) -> Pin< + Box>> + Send>, + > + Send + + Sync + + Clone, + O: GiteratedObject + Send + Sync, + D: GiteratedOperation + 'static, + >::Failure: Send, + S: Send + Sync + Clone + 'static, + O1: FromOperationState, +{ + fn operation_name(&self) -> &str { + D::operation_name() + } + + fn object_name(&self) -> &str { + O::object_name() + } + + async fn handle( + &self, + object: &O, + operation: D, + state: S, + operation_state: &StackOperationState, + ) -> Result> { + let o1 = O1::from_state(operation_state) + .await + .map_err(|e| OperationError::Internal(e.to_string()))?; + self.clone()(object, operation, state, o1).await + } +} + +#[async_trait::async_trait] +impl GiteratedOperationHandler<(O1, O2), O, D, S> for F +where + F: FnMut( + &O, + D, + S, + O1, + O2, + ) -> Pin< + Box>> + Send>, + > + Send + + Sync + + Clone, + O: GiteratedObject + Send + Sync, + D: GiteratedOperation + 'static, + >::Failure: Send, + S: Send + Sync + Clone + 'static, + O1: FromOperationState, + O2: FromOperationState, +{ + fn operation_name(&self) -> &str { + D::operation_name() + } + + fn object_name(&self) -> &str { + O::object_name() + } + + async fn handle( + &self, + object: &O, + operation: D, + state: S, + operation_state: &StackOperationState, + ) -> Result> { + let o1 = O1::from_state(operation_state) + .await + .map_err(|e| OperationError::Internal(e.to_string()))?; + let o2 = O2::from_state(operation_state) + .await + .map_err(|e| OperationError::Internal(e.to_string()))?; + self.clone()(object, operation, state, o1, o2).await + } +} + pub struct OperationWrapper { func: Box< dyn Fn( AnyObject, AnyOperation, S, + StackOperationState, ) -> Pin, OperationError>>> + Send>> + Send @@ -210,15 +314,16 @@ pub struct OperationWrapper { impl OperationWrapper { pub fn new< + A, O: GiteratedObject + Send + Sync, D: GiteratedOperation + 'static, - F: GiteratedOperationHandler + Send + Sync + 'static + Clone, + F: GiteratedOperationHandler + Send + Sync + 'static + Clone, >( handler: F, ) -> Self { let handler = Arc::new(Box::pin(handler)); Self { - func: Box::new(move |any_object, any_operation, state| { + func: Box::new(move |any_object, any_operation, state, operation_state| { let handler = handler.clone(); async move { let handler = handler.clone(); @@ -227,7 +332,9 @@ impl OperationWrapper { let operation: D = serde_json::from_value(any_operation.0.clone()) .map_err(|_| OperationError::Unhandled)?; - let result = handler.handle(&object, operation, state).await; + let result = handler + .handle(&object, operation, state, &operation_state) + .await; result .map(|success| serde_json::to_vec(&success).unwrap()) .map_err(|err| match err { @@ -251,7 +358,193 @@ impl OperationWrapper { object: AnyObject, operation: AnyOperation, state: S, + operation_state: &StackOperationState, ) -> Result, OperationError>> { - (self.func)(object, operation, state).await + (self.func)(object, operation, state, operation_state.clone()).await + } +} + +#[async_trait::async_trait] +pub trait FromOperationState: Sized + Clone + Send { + type Error: Serialize + DeserializeOwned; + + async fn from_state(state: &StackOperationState) -> Result>; +} + +#[async_trait::async_trait] +impl FromOperationState for BackendWrapper { + type Error = (); + + async fn from_state(state: &StackOperationState) -> Result> { + Ok(state.giterated_backend.clone()) + } +} + +#[async_trait::async_trait] +impl FromOperationState for StackOperationState { + type Error = (); + + async fn from_state( + state: &StackOperationState, + ) -> Result> { + Ok(state.clone()) + } +} + +#[derive(Clone)] +pub struct StackOperationState { + pub giterated_backend: BackendWrapper, +} + +#[derive(Clone)] +pub struct BackendWrapper { + sender: tokio::sync::mpsc::Sender<( + tokio::sync::oneshot::Sender>>, + WrappedOperation, + )>, + task: Arc>, +} + +pub struct WrappedOperation { + object: AnyObject, + operation_payload: AnyOperation, + operation_name: String, + state: StackOperationState, +} + +impl BackendWrapper { + pub fn new(backend: GiteratedBackend) -> Self { + // Spawn listener task + + let (send, mut recv) = channel::<( + tokio::sync::oneshot::Sender>>, + WrappedOperation, + )>(1024); + + let task = tokio::spawn(async move { + while let Some((responder, message)) = recv.recv().await { + let raw_result = backend + .object_operation( + message.object, + &message.operation_name, + message.operation_payload, + &message.state, + ) + .await; + + responder.send(raw_result).unwrap(); + } + error!("Error, thing's dead"); + }); + + Self { + sender: send, + task: Arc::new(task), + } + } + + pub async fn call(&self, operation: WrappedOperation) -> Result> { + let (sender, response) = tokio::sync::oneshot::channel(); + + self.sender + .send((sender, operation)) + .await + .map_err(|e| OperationError::Internal(e.to_string()))?; + + match response.await { + Ok(result) => Ok(result?), + Err(err) => Err(OperationError::Internal(err.to_string())), + } + } +} + +use std::fmt::Debug; + +#[async_trait::async_trait] +impl ObjectBackend for BackendWrapper { + async fn object_operation( + &self, + object: O, + operation: &str, + payload: D, + operation_state: &StackOperationState, + ) -> Result> + where + O: GiteratedObject + Debug, + D: GiteratedOperation + Debug, + { + let operation = WrappedOperation { + object: AnyObject(object.to_string()), + operation_name: operation.to_string(), + operation_payload: AnyOperation(serde_json::to_value(payload).unwrap()), + state: operation_state.clone(), + }; + + let raw_result = self.call(operation).await; + + match raw_result { + Ok(result) => Ok(serde_json::from_value(result) + .map_err(|e| OperationError::Internal(e.to_string()))?), + Err(err) => match err { + OperationError::Internal(internal) => { + warn!( + "Internal Error: {:?}", + OperationError::<()>::Internal(internal.clone()) + ); + + Err(OperationError::Internal(internal)) + } + OperationError::Unhandled => Err(OperationError::Unhandled), + OperationError::Operation(err) => Err(OperationError::Operation( + serde_json::from_value(err) + .map_err(|e| OperationError::Internal(e.to_string()))?, + )), + }, + } + } + + async fn get_object( + &self, + object_str: &str, + operation_state: &StackOperationState, + ) -> Result, OperationError> { + let operation = WrappedOperation { + object: AnyObject(object_str.to_string()), + operation_name: ObjectRequest::operation_name().to_string(), + operation_payload: AnyOperation( + serde_json::to_value(ObjectRequest(object_str.to_string())).unwrap(), + ), + state: operation_state.clone(), + }; + + let raw_result = self.call(operation).await; + + let object: ObjectResponse = match raw_result { + Ok(result) => Ok(serde_json::from_value(result) + .map_err(|e| OperationError::Internal(e.to_string()))?), + Err(err) => match err { + OperationError::Internal(internal) => { + warn!( + "Internal Error: {:?}", + OperationError::<()>::Internal(internal.clone()) + ); + + Err(OperationError::Internal(internal)) + } + OperationError::Unhandled => Err(OperationError::Unhandled), + OperationError::Operation(err) => Err(OperationError::Operation( + serde_json::from_value(err) + .map_err(|e| OperationError::Internal(e.to_string()))?, + )), + }, + }?; + + unsafe { + Ok(Object::new_unchecked( + O::from_str(&object.0) + .map_err(|_| OperationError::Internal("deserialize failure".to_string()))?, + self.clone(), + )) + } } } diff --git a/giterated-stack/src/state.rs b/giterated-stack/src/state.rs index 50ca121..07f25a0 100644 --- a/giterated-stack/src/state.rs +++ b/giterated-stack/src/state.rs @@ -1,3 +1,5 @@ +use std::any::Any; + /// A type which can be passed into a stateful handler. /// /// # Trait Bounds @@ -7,6 +9,6 @@ /// # Blanket Impl /// This trait is blanket-impl'd on any type that meets the requirements. You do not need /// to manually mark your state types with it. -pub trait HandlerState: Send + Sync + Clone + 'static {} +pub trait HandlerState: Any + Send + Sync + Clone + 'static {} impl HandlerState for T where T: Send + Sync + Clone + 'static {}