Fix stdio reloads (#402)

* First pass at MCP client refactor

* Restore original text instructions

Well most of them, I modified a few

* Move configurators to their own folder

It's less clusterd

* Remvoe override for Windsurf because we no longer need to use it

* Add Antigravity configs

Works like Windsurf, but it sucks ass

* Add some docs for properties

* Add comprehensive MCP client configurators documentation

* Add missing imports (#7)

* Handle Linux paths when unregistering CLI commands

* Construct a JSON error in a much more secure fashion

* Fix stdio auto-reconnect after domain reloads

We mirror what we've done with the HTTP/websocket connection

We also ensure the states from the stdio/HTTP connections are handled separately. Things now work as expected

* Fix ActiveMode to return resolved transport mode instead of preferred mode

The ActiveMode property now calls ResolvePreferredMode() to return the actual active transport mode rather than just the preferred mode setting.

* Minor improvements for stdio bridge

- Consolidated the !useHttp && isRunning checks into a single shouldResume flag.
- Wrapped the fire-and-forget StopAsync in a continuation that logs faults (matching the HTTP handler pattern).
- Wrapped StartAsync in a continuation that logs failures and only triggers the health check on success.

* Refactor TransportManager to use switch expressions and improve error handling

- Replace if-else chains with switch expressions for better readability and exhaustiveness checking
- Add GetClient() helper method to centralize client retrieval logic
- Wrap StopAsync in try-catch to log failures when stopping a failed transport
- Use client.TransportName instead of mode.ToString() for consistent naming in error messages
main
Marcus Sanatan 2025-11-27 19:33:26 -04:00 committed by GitHub
parent f94cb2460a
commit 17cd543fab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 235 additions and 109 deletions

View File

@ -11,6 +11,7 @@ namespace MCPForUnity.Editor.Constants
internal const string ValidationLevel = "MCPForUnity.ValidationLevel";
internal const string UnitySocketPort = "MCPForUnity.UnitySocketPort";
internal const string ResumeHttpAfterReload = "MCPForUnity.ResumeHttpAfterReload";
internal const string ResumeStdioAfterReload = "MCPForUnity.ResumeStdioAfterReload";
internal const string UvxPathOverride = "MCPForUnity.UvxPath";
internal const string ClaudeCliPathOverride = "MCPForUnity.ClaudeCliPath";

View File

@ -49,13 +49,21 @@ namespace MCPForUnity.Editor.Services
};
}
public bool IsRunning => _transportManager.GetState().IsConnected;
public bool IsRunning
{
get
{
var mode = ResolvePreferredMode();
return _transportManager.IsRunning(mode);
}
}
public int CurrentPort
{
get
{
var state = _transportManager.GetState();
var mode = ResolvePreferredMode();
var state = _transportManager.GetState(mode);
if (state.Port.HasValue)
{
return state.Port.Value;
@ -67,7 +75,7 @@ namespace MCPForUnity.Editor.Services
}
public bool IsAutoConnectMode => StdioBridgeHost.IsAutoConnectMode();
public TransportMode? ActiveMode => _transportManager.ActiveMode;
public TransportMode? ActiveMode => ResolvePreferredMode();
public async Task<bool> StartAsync()
{
@ -92,7 +100,8 @@ namespace MCPForUnity.Editor.Services
{
try
{
await _transportManager.StopAsync();
var mode = ResolvePreferredMode();
await _transportManager.StopAsync(mode);
}
catch (Exception ex)
{
@ -102,17 +111,17 @@ namespace MCPForUnity.Editor.Services
public async Task<BridgeVerificationResult> VerifyAsync()
{
var mode = _transportManager.ActiveMode ?? ResolvePreferredMode();
bool pingSucceeded = await _transportManager.VerifyAsync();
var state = _transportManager.GetState();
var mode = ResolvePreferredMode();
bool pingSucceeded = await _transportManager.VerifyAsync(mode);
var state = _transportManager.GetState(mode);
return BuildVerificationResult(state, mode, pingSucceeded);
}
public BridgeVerificationResult Verify(int port)
{
var mode = _transportManager.ActiveMode ?? ResolvePreferredMode();
bool pingSucceeded = _transportManager.VerifyAsync().GetAwaiter().GetResult();
var state = _transportManager.GetState();
var mode = ResolvePreferredMode();
bool pingSucceeded = _transportManager.VerifyAsync(mode).GetAwaiter().GetResult();
var state = _transportManager.GetState(mode);
if (mode == TransportMode.Stdio)
{

View File

@ -24,8 +24,8 @@ namespace MCPForUnity.Editor.Services
{
try
{
var bridge = MCPServiceLocator.Bridge;
bool shouldResume = bridge.IsRunning && bridge.ActiveMode == TransportMode.Http;
var transport = MCPServiceLocator.TransportManager;
bool shouldResume = transport.IsRunning(TransportMode.Http);
if (shouldResume)
{
@ -36,9 +36,9 @@ namespace MCPForUnity.Editor.Services
EditorPrefs.DeleteKey(EditorPrefKeys.ResumeHttpAfterReload);
}
if (bridge.IsRunning)
if (shouldResume)
{
var stopTask = bridge.StopAsync();
var stopTask = transport.StopAsync(TransportMode.Http);
stopTask.ContinueWith(t =>
{
if (t.IsFaulted && t.Exception != null)
@ -59,7 +59,9 @@ namespace MCPForUnity.Editor.Services
bool resume = false;
try
{
resume = EditorPrefs.GetBool(EditorPrefKeys.ResumeHttpAfterReload, false);
// Only resume HTTP if it is still the selected transport.
bool useHttp = EditorPrefs.GetBool(EditorPrefKeys.UseHttpTransport, true);
resume = useHttp && EditorPrefs.GetBool(EditorPrefKeys.ResumeHttpAfterReload, false);
if (resume)
{
EditorPrefs.DeleteKey(EditorPrefKeys.ResumeHttpAfterReload);
@ -90,7 +92,7 @@ namespace MCPForUnity.Editor.Services
{
try
{
var startTask = MCPServiceLocator.Bridge.StartAsync();
var startTask = MCPServiceLocator.TransportManager.StartAsync(TransportMode.Http);
startTask.ContinueWith(t =>
{
if (t.IsFaulted)
@ -123,7 +125,7 @@ namespace MCPForUnity.Editor.Services
{
try
{
bool started = await MCPServiceLocator.Bridge.StartAsync();
bool started = await MCPServiceLocator.TransportManager.StartAsync(TransportMode.Http);
if (!started)
{
McpLog.Warn("Failed to resume HTTP MCP bridge after domain reload");

View File

@ -0,0 +1,104 @@
using System;
using UnityEditor;
using MCPForUnity.Editor.Constants;
using MCPForUnity.Editor.Helpers;
using MCPForUnity.Editor.Services.Transport;
using MCPForUnity.Editor.Services.Transport.Transports;
namespace MCPForUnity.Editor.Services
{
/// <summary>
/// Ensures the legacy stdio bridge resumes after domain reloads, mirroring the HTTP handler.
/// </summary>
[InitializeOnLoad]
internal static class StdioBridgeReloadHandler
{
static StdioBridgeReloadHandler()
{
AssemblyReloadEvents.beforeAssemblyReload += OnBeforeAssemblyReload;
AssemblyReloadEvents.afterAssemblyReload += OnAfterAssemblyReload;
}
private static void OnBeforeAssemblyReload()
{
try
{
// Only persist resume intent when stdio is the active transport and the bridge is running.
bool useHttp = EditorPrefs.GetBool(EditorPrefKeys.UseHttpTransport, true);
bool isRunning = MCPServiceLocator.TransportManager.IsRunning(TransportMode.Stdio);
bool shouldResume = !useHttp && isRunning;
if (shouldResume)
{
EditorPrefs.SetBool(EditorPrefKeys.ResumeStdioAfterReload, true);
// Stop only the stdio bridge; leave HTTP untouched if it is running concurrently.
var stopTask = MCPServiceLocator.TransportManager.StopAsync(TransportMode.Stdio);
stopTask.ContinueWith(t =>
{
if (t.IsFaulted && t.Exception != null)
{
McpLog.Warn($"Error stopping stdio bridge before reload: {t.Exception.GetBaseException()?.Message}");
}
}, System.Threading.Tasks.TaskScheduler.Default);
}
else
{
EditorPrefs.DeleteKey(EditorPrefKeys.ResumeStdioAfterReload);
}
}
catch (Exception ex)
{
McpLog.Warn($"Failed to persist stdio reload flag: {ex.Message}");
}
}
private static void OnAfterAssemblyReload()
{
bool resume = false;
try
{
resume = EditorPrefs.GetBool(EditorPrefKeys.ResumeStdioAfterReload, false);
bool useHttp = EditorPrefs.GetBool(EditorPrefKeys.UseHttpTransport, true);
resume = resume && !useHttp;
if (resume)
{
EditorPrefs.DeleteKey(EditorPrefKeys.ResumeStdioAfterReload);
}
}
catch (Exception ex)
{
McpLog.Warn($"Failed to read stdio reload flag: {ex.Message}");
}
if (!resume)
{
return;
}
// Restart via TransportManager so state stays in sync; if it fails (port busy), rely on UI to retry.
TryStartBridgeImmediate();
}
private static void TryStartBridgeImmediate()
{
var startTask = MCPServiceLocator.TransportManager.StartAsync(TransportMode.Stdio);
startTask.ContinueWith(t =>
{
if (t.IsFaulted)
{
var baseEx = t.Exception?.GetBaseException();
McpLog.Warn($"Failed to resume stdio bridge after reload: {baseEx?.Message}");
return;
}
if (!t.Result)
{
McpLog.Warn("Failed to resume stdio bridge after domain reload");
return;
}
MCPForUnity.Editor.Windows.MCPForUnityEditorWindow.RequestHealthVerification();
}, System.Threading.Tasks.TaskScheduler.Default);
}
}
}

View File

@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 6e603c72a87974cf5b495cd683165fbf
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -10,8 +10,10 @@ namespace MCPForUnity.Editor.Services.Transport
/// </summary>
public class TransportManager
{
private IMcpTransportClient _active;
private TransportMode? _activeMode;
private IMcpTransportClient _httpClient;
private IMcpTransportClient _stdioClient;
private TransportState _httpState = TransportState.Disconnected("http");
private TransportState _stdioState = TransportState.Disconnected("stdio");
private Func<IMcpTransportClient> _webSocketFactory;
private Func<IMcpTransportClient> _stdioFactory;
@ -22,8 +24,8 @@ namespace MCPForUnity.Editor.Services.Transport
() => new StdioTransportClient());
}
public IMcpTransportClient ActiveTransport => _active;
public TransportMode? ActiveMode => _activeMode;
public IMcpTransportClient ActiveTransport => null; // Deprecated single-transport accessor
public TransportMode? ActiveMode => null; // Deprecated single-transport accessor
public void Configure(
Func<IMcpTransportClient> webSocketFactory,
@ -33,68 +35,115 @@ namespace MCPForUnity.Editor.Services.Transport
_stdioFactory = stdioFactory ?? throw new ArgumentNullException(nameof(stdioFactory));
}
private IMcpTransportClient GetOrCreateClient(TransportMode mode)
{
return mode switch
{
TransportMode.Http => _httpClient ??= _webSocketFactory(),
TransportMode.Stdio => _stdioClient ??= _stdioFactory(),
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}
private IMcpTransportClient GetClient(TransportMode mode)
{
return mode switch
{
TransportMode.Http => _httpClient,
TransportMode.Stdio => _stdioClient,
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}
public async Task<bool> StartAsync(TransportMode mode)
{
await StopAsync();
IMcpTransportClient client = GetOrCreateClient(mode);
IMcpTransportClient next = mode switch
{
TransportMode.Stdio => _stdioFactory(),
TransportMode.Http => _webSocketFactory(),
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode")
} ?? throw new InvalidOperationException($"Factory returned null for transport mode {mode}");
bool started = await next.StartAsync();
bool started = await client.StartAsync();
if (!started)
{
await next.StopAsync();
_active = null;
_activeMode = null;
return false;
}
_active = next;
_activeMode = mode;
return true;
}
public async Task StopAsync()
{
if (_active != null)
{
try
{
await _active.StopAsync();
await client.StopAsync();
}
catch (Exception ex)
{
McpLog.Warn($"Error while stopping transport {_active.TransportName}: {ex.Message}");
McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}");
}
finally
UpdateState(mode, TransportState.Disconnected(client.TransportName, "Failed to start"));
return false;
}
UpdateState(mode, client.State ?? TransportState.Connected(client.TransportName));
return true;
}
public async Task StopAsync(TransportMode? mode = null)
{
_active = null;
_activeMode = null;
async Task StopClient(IMcpTransportClient client, TransportMode clientMode)
{
if (client == null) return;
try { await client.StopAsync(); }
catch (Exception ex) { McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}"); }
finally { UpdateState(clientMode, TransportState.Disconnected(client.TransportName)); }
}
if (mode == null)
{
await StopClient(_httpClient, TransportMode.Http);
await StopClient(_stdioClient, TransportMode.Stdio);
return;
}
if (mode == TransportMode.Http)
{
await StopClient(_httpClient, TransportMode.Http);
}
else
{
await StopClient(_stdioClient, TransportMode.Stdio);
}
}
public async Task<bool> VerifyAsync()
public async Task<bool> VerifyAsync(TransportMode mode)
{
if (_active == null)
IMcpTransportClient client = GetClient(mode);
if (client == null)
{
return false;
}
return await _active.VerifyAsync();
bool ok = await client.VerifyAsync();
var state = client.State ?? TransportState.Disconnected(client.TransportName, "No state reported");
UpdateState(mode, state);
return ok;
}
public TransportState GetState()
public TransportState GetState(TransportMode mode)
{
if (_active == null)
return mode switch
{
return TransportState.Disconnected(_activeMode?.ToString()?.ToLowerInvariant() ?? "unknown", "Transport not started");
TransportMode.Http => _httpState,
TransportMode.Stdio => _stdioState,
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}
return _active.State ?? TransportState.Disconnected(_active.TransportName, "No state reported");
public bool IsRunning(TransportMode mode) => GetState(mode).IsConnected;
private void UpdateState(TransportMode mode, TransportState state)
{
switch (mode)
{
case TransportMode.Http:
_httpState = state;
break;
case TransportMode.Stdio:
_stdioState = state;
break;
default:
throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode");
}
}
}

View File

@ -57,7 +57,6 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
private static int mainThreadId;
private static int currentUnityPort = 6400;
private static bool isAutoConnectMode = false;
private static bool shouldRestartAfterReload = false;
private const ulong MaxFrameBytes = 64UL * 1024 * 1024;
private const int FrameIOTimeoutMs = 30000;
@ -162,8 +161,6 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
}
}
EditorApplication.quitting += Stop;
AssemblyReloadEvents.beforeAssemblyReload += OnBeforeAssemblyReload;
AssemblyReloadEvents.afterAssemblyReload += OnAfterAssemblyReload;
EditorApplication.playModeStateChanged += _ =>
{
if (ShouldAutoStartBridge())
@ -406,10 +403,6 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
listenerTask = Task.Run(() => ListenerLoopAsync(cts.Token));
CommandRegistry.Initialize();
EditorApplication.update += ProcessCommands;
try { AssemblyReloadEvents.beforeAssemblyReload -= OnBeforeAssemblyReload; } catch { }
try { AssemblyReloadEvents.beforeAssemblyReload += OnBeforeAssemblyReload; } catch { }
try { AssemblyReloadEvents.afterAssemblyReload -= OnAfterAssemblyReload; } catch { }
try { AssemblyReloadEvents.afterAssemblyReload += OnAfterAssemblyReload; } catch { }
try { EditorApplication.quitting -= Stop; } catch { }
try { EditorApplication.quitting += Stop; } catch { }
heartbeatSeq++;
@ -470,8 +463,6 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
}
try { EditorApplication.update -= ProcessCommands; } catch { }
try { AssemblyReloadEvents.beforeAssemblyReload -= OnBeforeAssemblyReload; } catch { }
try { AssemblyReloadEvents.afterAssemblyReload -= OnAfterAssemblyReload; } catch { }
try { EditorApplication.quitting -= Stop; } catch { }
try
@ -1023,47 +1014,6 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
return false;
}
private static void OnBeforeAssemblyReload()
{
if (isRunning)
{
shouldRestartAfterReload = true;
}
try { Stop(); } catch { }
}
private static void OnAfterAssemblyReload()
{
WriteHeartbeat(false, "idle");
LogBreadcrumb("Idle");
bool shouldResume = ShouldAutoStartBridge() || shouldRestartAfterReload;
if (shouldRestartAfterReload)
{
shouldRestartAfterReload = false;
}
if (!shouldResume)
{
return;
}
// If we're not compiling, try to bring the bridge up immediately to avoid depending on editor focus.
if (!IsCompiling())
{
try
{
Start();
return; // Successful immediate start; no need to schedule a delayed retry
}
catch (Exception ex)
{
// Fall through to delayed retry if immediate start fails
McpLog.Warn($"Immediate STDIO bridge restart after reload failed: {ex.Message}");
}
}
// Fallback path when compiling or if immediate start failed
ScheduleInitRetry();
}
private static void WriteHeartbeat(bool reloading, string reason = null)
{