Fix stdio reloads (#402)

* First pass at MCP client refactor

* Restore original text instructions

Well most of them, I modified a few

* Move configurators to their own folder

It's less clusterd

* Remvoe override for Windsurf because we no longer need to use it

* Add Antigravity configs

Works like Windsurf, but it sucks ass

* Add some docs for properties

* Add comprehensive MCP client configurators documentation

* Add missing imports (#7)

* Handle Linux paths when unregistering CLI commands

* Construct a JSON error in a much more secure fashion

* Fix stdio auto-reconnect after domain reloads

We mirror what we've done with the HTTP/websocket connection

We also ensure the states from the stdio/HTTP connections are handled separately. Things now work as expected

* Fix ActiveMode to return resolved transport mode instead of preferred mode

The ActiveMode property now calls ResolvePreferredMode() to return the actual active transport mode rather than just the preferred mode setting.

* Minor improvements for stdio bridge

- Consolidated the !useHttp && isRunning checks into a single shouldResume flag.
- Wrapped the fire-and-forget StopAsync in a continuation that logs faults (matching the HTTP handler pattern).
- Wrapped StartAsync in a continuation that logs failures and only triggers the health check on success.

* Refactor TransportManager to use switch expressions and improve error handling

- Replace if-else chains with switch expressions for better readability and exhaustiveness checking
- Add GetClient() helper method to centralize client retrieval logic
- Wrap StopAsync in try-catch to log failures when stopping a failed transport
- Use client.TransportName instead of mode.ToString() for consistent naming in error messages
main
Marcus Sanatan 2025-11-27 19:33:26 -04:00 committed by GitHub
parent f94cb2460a
commit 17cd543fab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 235 additions and 109 deletions

View File

@ -11,6 +11,7 @@ namespace MCPForUnity.Editor.Constants
internal const string ValidationLevel = "MCPForUnity.ValidationLevel"; internal const string ValidationLevel = "MCPForUnity.ValidationLevel";
internal const string UnitySocketPort = "MCPForUnity.UnitySocketPort"; internal const string UnitySocketPort = "MCPForUnity.UnitySocketPort";
internal const string ResumeHttpAfterReload = "MCPForUnity.ResumeHttpAfterReload"; internal const string ResumeHttpAfterReload = "MCPForUnity.ResumeHttpAfterReload";
internal const string ResumeStdioAfterReload = "MCPForUnity.ResumeStdioAfterReload";
internal const string UvxPathOverride = "MCPForUnity.UvxPath"; internal const string UvxPathOverride = "MCPForUnity.UvxPath";
internal const string ClaudeCliPathOverride = "MCPForUnity.ClaudeCliPath"; internal const string ClaudeCliPathOverride = "MCPForUnity.ClaudeCliPath";

View File

@ -49,13 +49,21 @@ namespace MCPForUnity.Editor.Services
}; };
} }
public bool IsRunning => _transportManager.GetState().IsConnected; public bool IsRunning
{
get
{
var mode = ResolvePreferredMode();
return _transportManager.IsRunning(mode);
}
}
public int CurrentPort public int CurrentPort
{ {
get get
{ {
var state = _transportManager.GetState(); var mode = ResolvePreferredMode();
var state = _transportManager.GetState(mode);
if (state.Port.HasValue) if (state.Port.HasValue)
{ {
return state.Port.Value; return state.Port.Value;
@ -67,7 +75,7 @@ namespace MCPForUnity.Editor.Services
} }
public bool IsAutoConnectMode => StdioBridgeHost.IsAutoConnectMode(); public bool IsAutoConnectMode => StdioBridgeHost.IsAutoConnectMode();
public TransportMode? ActiveMode => _transportManager.ActiveMode; public TransportMode? ActiveMode => ResolvePreferredMode();
public async Task<bool> StartAsync() public async Task<bool> StartAsync()
{ {
@ -92,7 +100,8 @@ namespace MCPForUnity.Editor.Services
{ {
try try
{ {
await _transportManager.StopAsync(); var mode = ResolvePreferredMode();
await _transportManager.StopAsync(mode);
} }
catch (Exception ex) catch (Exception ex)
{ {
@ -102,17 +111,17 @@ namespace MCPForUnity.Editor.Services
public async Task<BridgeVerificationResult> VerifyAsync() public async Task<BridgeVerificationResult> VerifyAsync()
{ {
var mode = _transportManager.ActiveMode ?? ResolvePreferredMode(); var mode = ResolvePreferredMode();
bool pingSucceeded = await _transportManager.VerifyAsync(); bool pingSucceeded = await _transportManager.VerifyAsync(mode);
var state = _transportManager.GetState(); var state = _transportManager.GetState(mode);
return BuildVerificationResult(state, mode, pingSucceeded); return BuildVerificationResult(state, mode, pingSucceeded);
} }
public BridgeVerificationResult Verify(int port) public BridgeVerificationResult Verify(int port)
{ {
var mode = _transportManager.ActiveMode ?? ResolvePreferredMode(); var mode = ResolvePreferredMode();
bool pingSucceeded = _transportManager.VerifyAsync().GetAwaiter().GetResult(); bool pingSucceeded = _transportManager.VerifyAsync(mode).GetAwaiter().GetResult();
var state = _transportManager.GetState(); var state = _transportManager.GetState(mode);
if (mode == TransportMode.Stdio) if (mode == TransportMode.Stdio)
{ {

View File

@ -24,8 +24,8 @@ namespace MCPForUnity.Editor.Services
{ {
try try
{ {
var bridge = MCPServiceLocator.Bridge; var transport = MCPServiceLocator.TransportManager;
bool shouldResume = bridge.IsRunning && bridge.ActiveMode == TransportMode.Http; bool shouldResume = transport.IsRunning(TransportMode.Http);
if (shouldResume) if (shouldResume)
{ {
@ -36,9 +36,9 @@ namespace MCPForUnity.Editor.Services
EditorPrefs.DeleteKey(EditorPrefKeys.ResumeHttpAfterReload); EditorPrefs.DeleteKey(EditorPrefKeys.ResumeHttpAfterReload);
} }
if (bridge.IsRunning) if (shouldResume)
{ {
var stopTask = bridge.StopAsync(); var stopTask = transport.StopAsync(TransportMode.Http);
stopTask.ContinueWith(t => stopTask.ContinueWith(t =>
{ {
if (t.IsFaulted && t.Exception != null) if (t.IsFaulted && t.Exception != null)
@ -59,7 +59,9 @@ namespace MCPForUnity.Editor.Services
bool resume = false; bool resume = false;
try try
{ {
resume = EditorPrefs.GetBool(EditorPrefKeys.ResumeHttpAfterReload, false); // Only resume HTTP if it is still the selected transport.
bool useHttp = EditorPrefs.GetBool(EditorPrefKeys.UseHttpTransport, true);
resume = useHttp && EditorPrefs.GetBool(EditorPrefKeys.ResumeHttpAfterReload, false);
if (resume) if (resume)
{ {
EditorPrefs.DeleteKey(EditorPrefKeys.ResumeHttpAfterReload); EditorPrefs.DeleteKey(EditorPrefKeys.ResumeHttpAfterReload);
@ -90,7 +92,7 @@ namespace MCPForUnity.Editor.Services
{ {
try try
{ {
var startTask = MCPServiceLocator.Bridge.StartAsync(); var startTask = MCPServiceLocator.TransportManager.StartAsync(TransportMode.Http);
startTask.ContinueWith(t => startTask.ContinueWith(t =>
{ {
if (t.IsFaulted) if (t.IsFaulted)
@ -123,7 +125,7 @@ namespace MCPForUnity.Editor.Services
{ {
try try
{ {
bool started = await MCPServiceLocator.Bridge.StartAsync(); bool started = await MCPServiceLocator.TransportManager.StartAsync(TransportMode.Http);
if (!started) if (!started)
{ {
McpLog.Warn("Failed to resume HTTP MCP bridge after domain reload"); McpLog.Warn("Failed to resume HTTP MCP bridge after domain reload");

View File

@ -0,0 +1,104 @@
using System;
using UnityEditor;
using MCPForUnity.Editor.Constants;
using MCPForUnity.Editor.Helpers;
using MCPForUnity.Editor.Services.Transport;
using MCPForUnity.Editor.Services.Transport.Transports;
namespace MCPForUnity.Editor.Services
{
/// <summary>
/// Ensures the legacy stdio bridge resumes after domain reloads, mirroring the HTTP handler.
/// </summary>
[InitializeOnLoad]
internal static class StdioBridgeReloadHandler
{
static StdioBridgeReloadHandler()
{
AssemblyReloadEvents.beforeAssemblyReload += OnBeforeAssemblyReload;
AssemblyReloadEvents.afterAssemblyReload += OnAfterAssemblyReload;
}
private static void OnBeforeAssemblyReload()
{
try
{
// Only persist resume intent when stdio is the active transport and the bridge is running.
bool useHttp = EditorPrefs.GetBool(EditorPrefKeys.UseHttpTransport, true);
bool isRunning = MCPServiceLocator.TransportManager.IsRunning(TransportMode.Stdio);
bool shouldResume = !useHttp && isRunning;
if (shouldResume)
{
EditorPrefs.SetBool(EditorPrefKeys.ResumeStdioAfterReload, true);
// Stop only the stdio bridge; leave HTTP untouched if it is running concurrently.
var stopTask = MCPServiceLocator.TransportManager.StopAsync(TransportMode.Stdio);
stopTask.ContinueWith(t =>
{
if (t.IsFaulted && t.Exception != null)
{
McpLog.Warn($"Error stopping stdio bridge before reload: {t.Exception.GetBaseException()?.Message}");
}
}, System.Threading.Tasks.TaskScheduler.Default);
}
else
{
EditorPrefs.DeleteKey(EditorPrefKeys.ResumeStdioAfterReload);
}
}
catch (Exception ex)
{
McpLog.Warn($"Failed to persist stdio reload flag: {ex.Message}");
}
}
private static void OnAfterAssemblyReload()
{
bool resume = false;
try
{
resume = EditorPrefs.GetBool(EditorPrefKeys.ResumeStdioAfterReload, false);
bool useHttp = EditorPrefs.GetBool(EditorPrefKeys.UseHttpTransport, true);
resume = resume && !useHttp;
if (resume)
{
EditorPrefs.DeleteKey(EditorPrefKeys.ResumeStdioAfterReload);
}
}
catch (Exception ex)
{
McpLog.Warn($"Failed to read stdio reload flag: {ex.Message}");
}
if (!resume)
{
return;
}
// Restart via TransportManager so state stays in sync; if it fails (port busy), rely on UI to retry.
TryStartBridgeImmediate();
}
private static void TryStartBridgeImmediate()
{
var startTask = MCPServiceLocator.TransportManager.StartAsync(TransportMode.Stdio);
startTask.ContinueWith(t =>
{
if (t.IsFaulted)
{
var baseEx = t.Exception?.GetBaseException();
McpLog.Warn($"Failed to resume stdio bridge after reload: {baseEx?.Message}");
return;
}
if (!t.Result)
{
McpLog.Warn("Failed to resume stdio bridge after domain reload");
return;
}
MCPForUnity.Editor.Windows.MCPForUnityEditorWindow.RequestHealthVerification();
}, System.Threading.Tasks.TaskScheduler.Default);
}
}
}

View File

@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 6e603c72a87974cf5b495cd683165fbf
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -10,8 +10,10 @@ namespace MCPForUnity.Editor.Services.Transport
/// </summary> /// </summary>
public class TransportManager public class TransportManager
{ {
private IMcpTransportClient _active; private IMcpTransportClient _httpClient;
private TransportMode? _activeMode; private IMcpTransportClient _stdioClient;
private TransportState _httpState = TransportState.Disconnected("http");
private TransportState _stdioState = TransportState.Disconnected("stdio");
private Func<IMcpTransportClient> _webSocketFactory; private Func<IMcpTransportClient> _webSocketFactory;
private Func<IMcpTransportClient> _stdioFactory; private Func<IMcpTransportClient> _stdioFactory;
@ -22,8 +24,8 @@ namespace MCPForUnity.Editor.Services.Transport
() => new StdioTransportClient()); () => new StdioTransportClient());
} }
public IMcpTransportClient ActiveTransport => _active; public IMcpTransportClient ActiveTransport => null; // Deprecated single-transport accessor
public TransportMode? ActiveMode => _activeMode; public TransportMode? ActiveMode => null; // Deprecated single-transport accessor
public void Configure( public void Configure(
Func<IMcpTransportClient> webSocketFactory, Func<IMcpTransportClient> webSocketFactory,
@ -33,68 +35,115 @@ namespace MCPForUnity.Editor.Services.Transport
_stdioFactory = stdioFactory ?? throw new ArgumentNullException(nameof(stdioFactory)); _stdioFactory = stdioFactory ?? throw new ArgumentNullException(nameof(stdioFactory));
} }
public async Task<bool> StartAsync(TransportMode mode) private IMcpTransportClient GetOrCreateClient(TransportMode mode)
{ {
await StopAsync(); return mode switch
IMcpTransportClient next = mode switch
{ {
TransportMode.Stdio => _stdioFactory(), TransportMode.Http => _httpClient ??= _webSocketFactory(),
TransportMode.Http => _webSocketFactory(), TransportMode.Stdio => _stdioClient ??= _stdioFactory(),
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode") _ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
} ?? throw new InvalidOperationException($"Factory returned null for transport mode {mode}"); };
bool started = await next.StartAsync();
if (!started)
{
await next.StopAsync();
_active = null;
_activeMode = null;
return false;
}
_active = next;
_activeMode = mode;
return true;
} }
public async Task StopAsync() private IMcpTransportClient GetClient(TransportMode mode)
{ {
if (_active != null) return mode switch
{
TransportMode.Http => _httpClient,
TransportMode.Stdio => _stdioClient,
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}
public async Task<bool> StartAsync(TransportMode mode)
{
IMcpTransportClient client = GetOrCreateClient(mode);
bool started = await client.StartAsync();
if (!started)
{ {
try try
{ {
await _active.StopAsync(); await client.StopAsync();
} }
catch (Exception ex) catch (Exception ex)
{ {
McpLog.Warn($"Error while stopping transport {_active.TransportName}: {ex.Message}"); McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}");
}
finally
{
_active = null;
_activeMode = null;
} }
UpdateState(mode, TransportState.Disconnected(client.TransportName, "Failed to start"));
return false;
}
UpdateState(mode, client.State ?? TransportState.Connected(client.TransportName));
return true;
}
public async Task StopAsync(TransportMode? mode = null)
{
async Task StopClient(IMcpTransportClient client, TransportMode clientMode)
{
if (client == null) return;
try { await client.StopAsync(); }
catch (Exception ex) { McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}"); }
finally { UpdateState(clientMode, TransportState.Disconnected(client.TransportName)); }
}
if (mode == null)
{
await StopClient(_httpClient, TransportMode.Http);
await StopClient(_stdioClient, TransportMode.Stdio);
return;
}
if (mode == TransportMode.Http)
{
await StopClient(_httpClient, TransportMode.Http);
}
else
{
await StopClient(_stdioClient, TransportMode.Stdio);
} }
} }
public async Task<bool> VerifyAsync() public async Task<bool> VerifyAsync(TransportMode mode)
{ {
if (_active == null) IMcpTransportClient client = GetClient(mode);
if (client == null)
{ {
return false; return false;
} }
return await _active.VerifyAsync();
bool ok = await client.VerifyAsync();
var state = client.State ?? TransportState.Disconnected(client.TransportName, "No state reported");
UpdateState(mode, state);
return ok;
} }
public TransportState GetState() public TransportState GetState(TransportMode mode)
{ {
if (_active == null) return mode switch
{ {
return TransportState.Disconnected(_activeMode?.ToString()?.ToLowerInvariant() ?? "unknown", "Transport not started"); TransportMode.Http => _httpState,
} TransportMode.Stdio => _stdioState,
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}
return _active.State ?? TransportState.Disconnected(_active.TransportName, "No state reported"); public bool IsRunning(TransportMode mode) => GetState(mode).IsConnected;
private void UpdateState(TransportMode mode, TransportState state)
{
switch (mode)
{
case TransportMode.Http:
_httpState = state;
break;
case TransportMode.Stdio:
_stdioState = state;
break;
default:
throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode");
}
} }
} }

View File

@ -57,7 +57,6 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
private static int mainThreadId; private static int mainThreadId;
private static int currentUnityPort = 6400; private static int currentUnityPort = 6400;
private static bool isAutoConnectMode = false; private static bool isAutoConnectMode = false;
private static bool shouldRestartAfterReload = false;
private const ulong MaxFrameBytes = 64UL * 1024 * 1024; private const ulong MaxFrameBytes = 64UL * 1024 * 1024;
private const int FrameIOTimeoutMs = 30000; private const int FrameIOTimeoutMs = 30000;
@ -162,8 +161,6 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
} }
} }
EditorApplication.quitting += Stop; EditorApplication.quitting += Stop;
AssemblyReloadEvents.beforeAssemblyReload += OnBeforeAssemblyReload;
AssemblyReloadEvents.afterAssemblyReload += OnAfterAssemblyReload;
EditorApplication.playModeStateChanged += _ => EditorApplication.playModeStateChanged += _ =>
{ {
if (ShouldAutoStartBridge()) if (ShouldAutoStartBridge())
@ -406,10 +403,6 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
listenerTask = Task.Run(() => ListenerLoopAsync(cts.Token)); listenerTask = Task.Run(() => ListenerLoopAsync(cts.Token));
CommandRegistry.Initialize(); CommandRegistry.Initialize();
EditorApplication.update += ProcessCommands; EditorApplication.update += ProcessCommands;
try { AssemblyReloadEvents.beforeAssemblyReload -= OnBeforeAssemblyReload; } catch { }
try { AssemblyReloadEvents.beforeAssemblyReload += OnBeforeAssemblyReload; } catch { }
try { AssemblyReloadEvents.afterAssemblyReload -= OnAfterAssemblyReload; } catch { }
try { AssemblyReloadEvents.afterAssemblyReload += OnAfterAssemblyReload; } catch { }
try { EditorApplication.quitting -= Stop; } catch { } try { EditorApplication.quitting -= Stop; } catch { }
try { EditorApplication.quitting += Stop; } catch { } try { EditorApplication.quitting += Stop; } catch { }
heartbeatSeq++; heartbeatSeq++;
@ -470,8 +463,6 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
} }
try { EditorApplication.update -= ProcessCommands; } catch { } try { EditorApplication.update -= ProcessCommands; } catch { }
try { AssemblyReloadEvents.beforeAssemblyReload -= OnBeforeAssemblyReload; } catch { }
try { AssemblyReloadEvents.afterAssemblyReload -= OnAfterAssemblyReload; } catch { }
try { EditorApplication.quitting -= Stop; } catch { } try { EditorApplication.quitting -= Stop; } catch { }
try try
@ -1023,47 +1014,6 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
return false; return false;
} }
private static void OnBeforeAssemblyReload()
{
if (isRunning)
{
shouldRestartAfterReload = true;
}
try { Stop(); } catch { }
}
private static void OnAfterAssemblyReload()
{
WriteHeartbeat(false, "idle");
LogBreadcrumb("Idle");
bool shouldResume = ShouldAutoStartBridge() || shouldRestartAfterReload;
if (shouldRestartAfterReload)
{
shouldRestartAfterReload = false;
}
if (!shouldResume)
{
return;
}
// If we're not compiling, try to bring the bridge up immediately to avoid depending on editor focus.
if (!IsCompiling())
{
try
{
Start();
return; // Successful immediate start; no need to schedule a delayed retry
}
catch (Exception ex)
{
// Fall through to delayed retry if immediate start fails
McpLog.Warn($"Immediate STDIO bridge restart after reload failed: {ex.Message}");
}
}
// Fallback path when compiling or if immediate start failed
ScheduleInitRetry();
}
private static void WriteHeartbeat(bool reloading, string reason = null) private static void WriteHeartbeat(bool reloading, string reason = null)
{ {