pub mod handler; pub mod state; use std::{collections::HashMap, future::Future, pin::Pin, str::FromStr, sync::Arc}; use futures_util::FutureExt; use giterated_models::{ error::OperationError, instance::Instance, object::{ AnyObject, GiteratedObject, Object, ObjectRequest, ObjectRequestError, ObjectResponse, }, object_backend::ObjectBackend, operation::{AnyOperation, GiteratedOperation}, repository::Repository, user::User, }; use handler::GiteratedBackend; use serde::{de::DeserializeOwned, Serialize}; use serde_json::Value; use state::HandlerState; use tokio::{sync::mpsc::channel, task::JoinHandle}; use tracing::{error, warn}; #[derive(Clone, Debug, Hash, Eq, PartialEq)] struct ObjectOperationPair { object_name: String, operation_name: String, } pub struct OperationHandlers { operations: HashMap>, get_object: Vec>, } impl Default for OperationHandlers { fn default() -> Self { Self { operations: HashMap::new(), get_object: Vec::new(), } } } impl OperationHandlers { pub fn insert< A, O: GiteratedObject + Send + Sync, D: GiteratedOperation + 'static, H: GiteratedOperationHandler + Send + Sync + 'static + Clone, >( &mut self, handler: H, ) -> &mut Self { let object_name = handler.object_name().to_string(); let operation_name = handler.operation_name().to_string(); let wrapped = OperationWrapper::new(handler); let pair = ObjectOperationPair { object_name, operation_name, }; assert!(self.operations.insert(pair, wrapped).is_none()); self } pub fn register_object(&mut self) -> &mut Self { let closure = |_: &Instance, operation: ObjectRequest, _state| { async move { if O::from_str(&operation.0).is_ok() { Ok(ObjectResponse(operation.0)) } else { Err(OperationError::Unhandled) } } .boxed() }; let wrapped = OperationWrapper::new(closure); self.get_object.push(wrapped); self } pub async fn handle( &self, object: &O, operation_name: &str, operation: AnyOperation, state: S, operation_state: &StackOperationState, ) -> Result, OperationError>> { // TODO let object = object.to_string(); let object_name = { if User::from_str(&object).is_ok() { User::object_name() } else if Repository::from_str(&object).is_ok() { Repository::object_name() } else if Instance::from_str(&object).is_ok() { Instance::object_name() } else { return Err(OperationError::Unhandled); } } .to_string(); let target_handler = ObjectOperationPair { object_name, operation_name: operation_name.to_string(), }; if let Some(handler) = self.operations.get(&target_handler) { handler .handle( AnyObject(object.to_string()), operation.clone(), state.clone(), operation_state, ) .await } else { Err(OperationError::Unhandled) } } pub async fn resolve_object( &self, instance: AnyObject, request: ObjectRequest, state: S, operation_state: &StackOperationState, ) -> Result, OperationError>> { for handler in self.get_object.iter() { if let Ok(response) = handler .handle( instance.clone(), AnyOperation(serde_json::to_value(request.clone()).unwrap()), state.clone(), operation_state, ) .await { return Ok(response); } } Err(OperationError::Unhandled) } } #[async_trait::async_trait] pub trait GiteratedOperationHandler< L, O: GiteratedObject, D: GiteratedOperation, S: Send + Sync + Clone, > { fn operation_name(&self) -> &str; fn object_name(&self) -> &str; async fn handle( &self, object: &O, operation: D, state: S, operation_state: &StackOperationState, ) -> Result>; } #[async_trait::async_trait] impl GiteratedOperationHandler<(), O, D, S> for F where F: FnMut( &O, D, S, ) -> Pin< Box>> + Send>, > + Send + Sync + Clone, O: GiteratedObject + Send + Sync, D: GiteratedOperation + 'static, >::Failure: Send, S: Send + Sync + Clone + 'static, { fn operation_name(&self) -> &str { D::operation_name() } fn object_name(&self) -> &str { O::object_name() } async fn handle( &self, object: &O, operation: D, state: S, _operation_state: &StackOperationState, ) -> Result> { self.clone()(object, operation, state).await } } #[async_trait::async_trait] impl GiteratedOperationHandler<(O1,), O, D, S> for F where F: FnMut( &O, D, S, O1, ) -> Pin< Box>> + Send>, > + Send + Sync + Clone, O: GiteratedObject + Send + Sync, D: GiteratedOperation + 'static, >::Failure: Send, S: Send + Sync + Clone + 'static, O1: FromOperationState, { fn operation_name(&self) -> &str { D::operation_name() } fn object_name(&self) -> &str { O::object_name() } async fn handle( &self, object: &O, operation: D, state: S, operation_state: &StackOperationState, ) -> Result> { let o1 = O1::from_state(operation_state) .await .map_err(|e| OperationError::Internal(e.to_string()))?; self.clone()(object, operation, state, o1).await } } #[async_trait::async_trait] impl GiteratedOperationHandler<(O1, O2), O, D, S> for F where F: FnMut( &O, D, S, O1, O2, ) -> Pin< Box>> + Send>, > + Send + Sync + Clone, O: GiteratedObject + Send + Sync, D: GiteratedOperation + 'static, >::Failure: Send, S: Send + Sync + Clone + 'static, O1: FromOperationState, O2: FromOperationState, { fn operation_name(&self) -> &str { D::operation_name() } fn object_name(&self) -> &str { O::object_name() } async fn handle( &self, object: &O, operation: D, state: S, operation_state: &StackOperationState, ) -> Result> { let o1 = O1::from_state(operation_state) .await .map_err(|e| OperationError::Internal(e.to_string()))?; let o2 = O2::from_state(operation_state) .await .map_err(|e| OperationError::Internal(e.to_string()))?; self.clone()(object, operation, state, o1, o2).await } } pub struct OperationWrapper { func: Box< dyn Fn( AnyObject, AnyOperation, S, StackOperationState, ) -> Pin, OperationError>>> + Send>> + Send + Sync, >, object_name: String, } impl OperationWrapper { pub fn new< A, O: GiteratedObject + Send + Sync, D: GiteratedOperation + 'static, F: GiteratedOperationHandler + Send + Sync + 'static + Clone, >( handler: F, ) -> Self { let handler = Arc::new(Box::pin(handler)); Self { func: Box::new(move |any_object, any_operation, state, operation_state| { let handler = handler.clone(); async move { let handler = handler.clone(); let object: O = O::from_object_str(&any_object.0).map_err(|_| OperationError::Unhandled)?; let operation: D = serde_json::from_value(any_operation.0.clone()) .map_err(|_| OperationError::Unhandled)?; let result = handler .handle(&object, operation, state, &operation_state) .await; result .map(|success| serde_json::to_vec(&success).unwrap()) .map_err(|err| match err { OperationError::Operation(err) => { OperationError::Operation(serde_json::to_vec(&err).unwrap()) } OperationError::Internal(internal) => { OperationError::Internal(internal) } OperationError::Unhandled => OperationError::Unhandled, }) } .boxed() }), object_name: O::object_name().to_string(), } } async fn handle( &self, object: AnyObject, operation: AnyOperation, state: S, operation_state: &StackOperationState, ) -> Result, OperationError>> { (self.func)(object, operation, state, operation_state.clone()).await } } #[async_trait::async_trait] pub trait FromOperationState: Sized + Clone + Send { type Error: Serialize + DeserializeOwned; async fn from_state(state: &StackOperationState) -> Result>; } #[async_trait::async_trait] impl FromOperationState for BackendWrapper { type Error = (); async fn from_state(state: &StackOperationState) -> Result> { Ok(state.giterated_backend.clone()) } } #[async_trait::async_trait] impl FromOperationState for StackOperationState { type Error = (); async fn from_state( state: &StackOperationState, ) -> Result> { Ok(state.clone()) } } #[derive(Clone)] pub struct StackOperationState { pub giterated_backend: BackendWrapper, } #[derive(Clone)] pub struct BackendWrapper { sender: tokio::sync::mpsc::Sender<( tokio::sync::oneshot::Sender>>, WrappedOperation, )>, task: Arc>, } pub struct WrappedOperation { object: AnyObject, operation_payload: AnyOperation, operation_name: String, state: StackOperationState, } impl BackendWrapper { pub fn new(backend: GiteratedBackend) -> Self { // Spawn listener task let (send, mut recv) = channel::<( tokio::sync::oneshot::Sender>>, WrappedOperation, )>(1024); let task = tokio::spawn(async move { while let Some((responder, message)) = recv.recv().await { let raw_result = backend .object_operation( message.object, &message.operation_name, message.operation_payload, &message.state, ) .await; responder.send(raw_result).unwrap(); } error!("Error, thing's dead"); }); Self { sender: send, task: Arc::new(task), } } pub async fn call(&self, operation: WrappedOperation) -> Result> { let (sender, response) = tokio::sync::oneshot::channel(); self.sender .send((sender, operation)) .await .map_err(|e| OperationError::Internal(e.to_string()))?; match response.await { Ok(result) => Ok(result?), Err(err) => Err(OperationError::Internal(err.to_string())), } } } use std::fmt::Debug; #[async_trait::async_trait] impl ObjectBackend for BackendWrapper { async fn object_operation( &self, object: O, operation: &str, payload: D, operation_state: &StackOperationState, ) -> Result> where O: GiteratedObject + Debug, D: GiteratedOperation + Debug, { let operation = WrappedOperation { object: AnyObject(object.to_string()), operation_name: operation.to_string(), operation_payload: AnyOperation(serde_json::to_value(payload).unwrap()), state: operation_state.clone(), }; let raw_result = self.call(operation).await; match raw_result { Ok(result) => Ok(serde_json::from_value(result) .map_err(|e| OperationError::Internal(e.to_string()))?), Err(err) => match err { OperationError::Internal(internal) => { warn!( "Internal Error: {:?}", OperationError::<()>::Internal(internal.clone()) ); Err(OperationError::Internal(internal)) } OperationError::Unhandled => Err(OperationError::Unhandled), OperationError::Operation(err) => Err(OperationError::Operation( serde_json::from_value(err) .map_err(|e| OperationError::Internal(e.to_string()))?, )), }, } } async fn get_object( &self, object_str: &str, operation_state: &StackOperationState, ) -> Result, OperationError> { let operation = WrappedOperation { object: AnyObject(object_str.to_string()), operation_name: ObjectRequest::operation_name().to_string(), operation_payload: AnyOperation( serde_json::to_value(ObjectRequest(object_str.to_string())).unwrap(), ), state: operation_state.clone(), }; let raw_result = self.call(operation).await; let object: ObjectResponse = match raw_result { Ok(result) => Ok(serde_json::from_value(result) .map_err(|e| OperationError::Internal(e.to_string()))?), Err(err) => match err { OperationError::Internal(internal) => { warn!( "Internal Error: {:?}", OperationError::<()>::Internal(internal.clone()) ); Err(OperationError::Internal(internal)) } OperationError::Unhandled => Err(OperationError::Unhandled), OperationError::Operation(err) => Err(OperationError::Operation( serde_json::from_value(err) .map_err(|e| OperationError::Internal(e.to_string()))?, )), }, }?; unsafe { Ok(Object::new_unchecked( O::from_str(&object.0) .map_err(|_| OperationError::Internal("deserialize failure".to_string()))?, self.clone(), )) } } }