using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.ExceptionServices;
using System.Threading.Tasks;
using MCPForUnity.Editor.Helpers;
using MCPForUnity.Editor.Resources;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;

namespace MCPForUnity.Editor.Tools
{
    /// <summary>
    /// Holds information about a registered command handler.
    /// The registry populates exactly one of <see cref="SyncHandler"/> or <see cref="AsyncHandler"/>.
    /// </summary>
    class HandlerInfo
    {
        /// <summary>The name the command is registered under.</summary>
        public string CommandName { get; }

        /// <summary>Synchronous implementation, or null when the handler is asynchronous.</summary>
        public Func<JObject, object> SyncHandler { get; }

        /// <summary>Task-based implementation, or null when the handler is synchronous.</summary>
        public Func<JObject, Task<object>> AsyncHandler { get; }

        /// <summary>True when the handler was registered with an asynchronous implementation.</summary>
        public bool IsAsync => AsyncHandler != null;

        /// <param name="commandName">The registered command name.</param>
        /// <param name="syncHandler">Synchronous implementation (may be null).</param>
        /// <param name="asyncHandler">Asynchronous implementation (may be null).</param>
        public HandlerInfo(string commandName, Func<JObject, object> syncHandler, Func<JObject, Task<object>> asyncHandler)
        {
            CommandName = commandName;
            SyncHandler = syncHandler;
            AsyncHandler = asyncHandler;
        }
    }

    /// <summary>
    /// Registry for all MCP command handlers, discovered via reflection.
    /// Handles both MCP tools and resources.
    /// </summary>
    public static class CommandRegistry
    {
        private static readonly Dictionary<string, HandlerInfo> _handlers = new();
        private static bool _initialized = false;

        /// <summary>
        /// Initialize and auto-discover all tools and resources marked with
        /// [McpForUnityTool] or [McpForUnityResource]. Subsequent calls are no-ops.
        /// </summary>
        public static void Initialize()
        {
            if (_initialized) return;

            AutoDiscoverCommands();
            _initialized = true;
        }

        private static string ToSnakeCase(string name) => StringCaseUtility.ToSnakeCase(name);

        /// <summary>
        /// Auto-discover all types with [McpForUnityTool] or [McpForUnityResource] attributes
        /// across the non-dynamic assemblies loaded in the current AppDomain.
        /// Tools are registered before resources, so a resource with the same name replaces a tool.
        /// </summary>
        private static void AutoDiscoverCommands()
        {
            try
            {
                var allTypes = AppDomain.CurrentDomain.GetAssemblies()
                    .Where(a => !a.IsDynamic)
                    .SelectMany(a => GetLoadableTypes(a))
                    .ToList();

                // Discover tools
                int toolCount = 0;
                foreach (var type in allTypes.Where(t => HasAttribute<McpForUnityToolAttribute>(t)))
                {
                    if (RegisterCommandType(type, isResource: false)) toolCount++;
                }

                // Discover resources
                int resourceCount = 0;
                foreach (var type in allTypes.Where(t => HasAttribute<McpForUnityResourceAttribute>(t)))
                {
                    if (RegisterCommandType(type, isResource: true)) resourceCount++;
                }

                McpLog.Info($"Auto-discovered {toolCount} tools and {resourceCount} resources ({_handlers.Count} total handlers)", false);
            }
            catch (Exception ex)
            {
                McpLog.Error($"Failed to auto-discover MCP commands: {ex.Message}");
            }
        }

        /// <summary>
        /// Returns the types of <paramref name="assembly"/> that could be loaded.
        /// When some types fail to load, the successfully loaded ones are still returned
        /// instead of discarding the whole assembly.
        /// </summary>
        private static IEnumerable<Type> GetLoadableTypes(Assembly assembly)
        {
            try
            {
                return assembly.GetTypes();
            }
            catch (ReflectionTypeLoadException ex)
            {
                // Types contains null entries for the types that failed to load.
                return ex.Types.Where(t => t != null);
            }
            catch (Exception)
            {
                // Best-effort discovery: an unreadable assembly contributes no handlers.
                return Type.EmptyTypes;
            }
        }

        /// <summary>
        /// Best-effort attribute probe. Returns false instead of throwing when the attribute
        /// metadata of <paramref name="type"/> cannot be resolved, so one broken type does
        /// not abort discovery for every other type.
        /// </summary>
        private static bool HasAttribute<TAttribute>(Type type) where TAttribute : Attribute
        {
            try
            {
                return type.GetCustomAttribute<TAttribute>() != null;
            }
            catch (Exception)
            {
                return false;
            }
        }

        /// <summary>
        /// Register a command type (tool or resource) with the registry.
        /// </summary>
        /// <param name="type">A type carrying [McpForUnityTool] or [McpForUnityResource].</param>
        /// <param name="isResource">True to read the resource attribute, false for the tool attribute.</param>
        /// <returns>True if successfully registered, false otherwise.</returns>
        private static bool RegisterCommandType(Type type, bool isResource)
        {
            string typeLabel = isResource ? "resource" : "tool";

            // Get command name from the appropriate attribute
            string commandName = isResource
                ? type.GetCustomAttribute<McpForUnityResourceAttribute>()?.ResourceName
                : type.GetCustomAttribute<McpForUnityToolAttribute>()?.CommandName;

            // Auto-generate command name if not explicitly provided
            if (string.IsNullOrEmpty(commandName))
            {
                commandName = ToSnakeCase(type.Name);
            }

            // Find HandleCommand method
            var method = type.GetMethod(
                "HandleCommand",
                BindingFlags.Public | BindingFlags.Static,
                null,
                new[] { typeof(JObject) },
                null
            );

            if (method == null)
            {
                McpLog.Warn(
                    $"MCP {typeLabel} {type.Name} is marked with [McpForUnity{(isResource ? "Resource" : "Tool")}] " +
                    $"but has no public static HandleCommand(JObject) method"
                );
                return false;
            }

            try
            {
                HandlerInfo handlerInfo;

                if (typeof(Task).IsAssignableFrom(method.ReturnType))
                {
                    var asyncHandler = CreateAsyncHandlerDelegate(method, commandName);
                    handlerInfo = new HandlerInfo(commandName, null, asyncHandler);
                }
                else
                {
                    // Throws ArgumentException when the return type is not compatible with object
                    // (e.g. void or a value type); handled by the catch below.
                    var handler = (Func<JObject, object>)Delegate.CreateDelegate(
                        typeof(Func<JObject, object>),
                        method
                    );
                    handlerInfo = new HandlerInfo(commandName, handler, null);
                }

                // Warn only when an existing handler is actually about to be replaced.
                if (_handlers.ContainsKey(commandName))
                {
                    McpLog.Warn(
                        $"Duplicate command name '{commandName}' detected. " +
                        $"{typeLabel} {type.Name} will override previously registered handler."
                    );
                }

                _handlers[commandName] = handlerInfo;
                return true;
            }
            catch (Exception ex)
            {
                McpLog.Error($"Failed to register {typeLabel} {type.Name}: {ex.Message}");
                return false;
            }
        }

        /// <summary>
        /// Get a command handler by name.
        /// </summary>
        /// <exception cref="InvalidOperationException">No handler is registered under <paramref name="commandName"/>.</exception>
        private static HandlerInfo GetHandlerInfo(string commandName)
        {
            if (!_handlers.TryGetValue(commandName, out var handler))
            {
                throw new InvalidOperationException(
                    $"Unknown or unsupported command type: {commandName}"
                );
            }
            return handler;
        }

        /// <summary>
        /// Get a synchronous command handler by name.
        /// </summary>
        /// <param name="commandName">The registered command name.</param>
        /// <returns>The synchronous handler delegate.</returns>
        /// <exception cref="InvalidOperationException">The command is unknown or asynchronous.</exception>
        public static Func<JObject, object> GetHandler(string commandName)
        {
            var handlerInfo = GetHandlerInfo(commandName);
            if (handlerInfo.IsAsync)
            {
                throw new InvalidOperationException(
                    $"Command '{commandName}' is asynchronous and must be executed via ExecuteCommand"
                );
            }

            return handlerInfo.SyncHandler;
        }

        /// <summary>
        /// Execute a command handler, supporting both synchronous handlers and asynchronous
        /// (Task-returning) handlers. For asynchronous handlers, <paramref name="tcs"/> is completed
        /// with a serialized JSON response once the handler finishes or fails.
        /// </summary>
        /// <param name="commandName">The command name to execute.</param>
        /// <param name="params">Command parameters.</param>
        /// <param name="tcs">TaskCompletionSource to complete when the async operation finishes; required for async commands.</param>
        /// <returns>The result for synchronous commands, or null for async commands (the TCS will be completed later).</returns>
        /// <exception cref="ArgumentNullException">The command is asynchronous and <paramref name="tcs"/> is null.</exception>
        public static object ExecuteCommand(string commandName, JObject @params, TaskCompletionSource<string> tcs)
        {
            var handlerInfo = GetHandlerInfo(commandName);

            if (handlerInfo.IsAsync)
            {
                // Without a TCS there is nowhere to deliver the result; fail fast before running the handler.
                if (tcs == null)
                {
                    throw new ArgumentNullException(nameof(tcs));
                }

                ExecuteAsyncHandler(handlerInfo, @params, commandName, tcs);
                return null;
            }

            if (handlerInfo.SyncHandler == null)
            {
                throw new InvalidOperationException($"Handler for '{commandName}' does not provide a synchronous implementation");
            }

            return handlerInfo.SyncHandler(@params);
        }

        /// <summary>
        /// Execute a command handler and return its raw result, regardless of sync or async implementation.
        /// Used internally for features like batch execution where commands need to be composed.
        /// </summary>
        /// <param name="commandName">The registered command to execute.</param>
        /// <param name="params">Parameters to pass to the command (optional; null becomes an empty object).</param>
        /// <returns>A task producing the handler's raw (unserialized) result.</returns>
        public static Task<object> InvokeCommandAsync(string commandName, JObject @params)
        {
            var handlerInfo = GetHandlerInfo(commandName);
            var payload = @params ?? new JObject();

            if (handlerInfo.IsAsync)
            {
                if (handlerInfo.AsyncHandler == null)
                {
                    throw new InvalidOperationException($"Async handler for '{commandName}' is not configured correctly");
                }

                return handlerInfo.AsyncHandler(payload);
            }

            if (handlerInfo.SyncHandler == null)
            {
                throw new InvalidOperationException($"Handler for '{commandName}' does not provide a synchronous implementation");
            }

            object result = handlerInfo.SyncHandler(payload);
            return Task.FromResult(result);
        }

        /// <summary>
        /// Create a delegate for an async handler method that returns Task or Task&lt;T&gt;.
        /// The delegate invokes the method, awaits its completion, and returns the task's result
        /// (or null for a non-generic Task).
        /// </summary>
        /// <param name="method">The static HandleCommand(JObject) method to wrap.</param>
        /// <param name="commandName">Command name used in error messages.</param>
        private static Func<JObject, Task<object>> CreateAsyncHandlerDelegate(MethodInfo method, string commandName)
        {
            return async (JObject parameters) =>
            {
                object rawResult;
                try
                {
                    rawResult = method.Invoke(null, new object[] { parameters });
                }
                catch (TargetInvocationException ex) when (ex.InnerException != null)
                {
                    // Rethrow the handler's own exception with its original stack trace intact.
                    ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
                    throw; // unreachable; satisfies definite assignment
                }

                if (rawResult == null)
                {
                    return null;
                }

                if (rawResult is not Task task)
                {
                    throw new InvalidOperationException(
                        $"Async handler '{commandName}' returned an object that is not a Task"
                    );
                }

                // Resume on the captured context (the Unity editor's main thread, if called from there).
                await task.ConfigureAwait(true);

                return GetTaskResult(task);
            };
        }

        /// <summary>
        /// Returns the value of a completed Task&lt;T&gt;, or null for a non-generic Task.
        /// Tasks produced by <c>async Task</c> methods are internally Task&lt;VoidTaskResult&gt;;
        /// that placeholder is not a real result and is reported as null.
        /// </summary>
        private static object GetTaskResult(Task task)
        {
            for (Type t = task.GetType(); t != null && t != typeof(Task); t = t.BaseType)
            {
                if (!t.IsGenericType || t.GetGenericTypeDefinition() != typeof(Task<>))
                {
                    continue;
                }

                if (t.GetGenericArguments()[0].FullName == "System.Threading.Tasks.VoidTaskResult")
                {
                    return null;
                }

                return t.GetProperty("Result")?.GetValue(task);
            }

            return null;
        }

        /// <summary>
        /// Start an async handler and arrange for <paramref name="tcs"/> to be completed with a
        /// success or error JSON response when it finishes. Returns without waiting.
        /// </summary>
        private static void ExecuteAsyncHandler(
            HandlerInfo handlerInfo,
            JObject parameters,
            string commandName,
            TaskCompletionSource<string> tcs)
        {
            if (handlerInfo.AsyncHandler == null)
            {
                throw new InvalidOperationException($"Async handler for '{commandName}' is not configured correctly");
            }

            Task<object> handlerTask;

            try
            {
                handlerTask = handlerInfo.AsyncHandler(parameters);
            }
            catch (Exception ex)
            {
                ReportAsyncFailure(commandName, tcs, ex);
                return;
            }

            if (handlerTask == null)
            {
                CompleteAsyncCommand(commandName, tcs, null);
                return;
            }

            // Fire-and-forget: every exception is caught inside and routed to the TCS.
            _ = AwaitHandlerAsync();

            async Task AwaitHandlerAsync()
            {
                try
                {
                    var finalResult = await handlerTask.ConfigureAwait(true);
                    CompleteAsyncCommand(commandName, tcs, finalResult);
                }
                catch (Exception ex)
                {
                    ReportAsyncFailure(commandName, tcs, ex);
                }
            }
        }

        /// <summary>
        /// Complete the TaskCompletionSource for an async command with a success result,
        /// serialized as <c>{"status":"success","result":...}</c>.
        /// If serialization fails, the TCS is completed with an error response instead.
        /// </summary>
        /// <param name="commandName">Command name used in log messages.</param>
        /// <param name="tcs">The completion source to complete.</param>
        /// <param name="result">The handler's result (may be null).</param>
        private static void CompleteAsyncCommand(string commandName, TaskCompletionSource<string> tcs, object result)
        {
            try
            {
                var response = new { status = "success", result };
                string json = JsonConvert.SerializeObject(response);

                if (!tcs.TrySetResult(json))
                {
                    McpLog.Warn($"TCS for async command '{commandName}' was already completed");
                }
            }
            catch (Exception ex)
            {
                McpLog.Error($"Error completing async command '{commandName}': {ex.Message}\n{ex.StackTrace}");
                ReportAsyncFailure(commandName, tcs, ex);
            }
        }

        /// <summary>
        /// Report an error that occurred during async command execution.
        /// Logs the error and completes the TaskCompletionSource with an error response
        /// (status, message, command name and stack trace). Falls back to a fixed JSON string
        /// if the error response itself cannot be serialized.
        /// </summary>
        /// <param name="commandName">Command name used in the response and log messages.</param>
        /// <param name="tcs">The completion source to complete.</param>
        /// <param name="ex">The exception to report.</param>
        private static void ReportAsyncFailure(string commandName, TaskCompletionSource<string> tcs, Exception ex)
        {
            McpLog.Error($"Error in async command '{commandName}': {ex.Message}\n{ex.StackTrace}");

            var errorResponse = new
            {
                status = "error",
                error = ex.Message,
                command = commandName,
                stackTrace = ex.StackTrace
            };

            string json;
            try
            {
                json = JsonConvert.SerializeObject(errorResponse);
            }
            catch (Exception serializationEx)
            {
                McpLog.Error($"Failed to serialize error response for '{commandName}': {serializationEx.Message}");
                json = "{\"status\":\"error\",\"error\":\"Failed to complete command\"}";
            }

            if (!tcs.TrySetResult(json))
            {
                McpLog.Warn($"TCS for async command '{commandName}' was already completed when trying to report error");
            }
        }
    }
}