Remote server auth (#644)

* Disable the gloabl default to first session when hosting remotely

* Remove calls to /plugin/sessions

The newer /api/instances covers that data, and we want to remove these "expose all" endpoints

* Disable CLI routes when running in remote hosted mode

* Update server README

* feat: add API key authentication support for remote-hosted HTTP transport

- Add API key field to connection UI (visible only in HTTP Remote mode)
- Add "Get API Key" and "Clear" buttons with login URL retrieval
- Include X-API-Key header in WebSocket connections when configured
- Add API key to CLI commands (mcp add, claude mcp add) when set
- Update config.json generation to include headers with API key
- Add API key validation service with caching and configurable endpoints
- Add /api/auth/login-url endpoint

* feat: add environment variable support for HTTP remote hosted mode

- Add UNITY_MCP_HTTP_REMOTE_HOSTED environment variable as alternative to --http-remote-hosted flag
- Accept "true", "1", or "yes" values (case-insensitive)
- Update CLI help text to document environment variable option

* feat: add user isolation enforcement for remote-hosted mode session listing

- Raise ValueError when list_sessions() called without user_id in remote-hosted mode
- Add comprehensive integration tests for multi-user session isolation
- Add unit tests for PluginRegistry user-scoped session filtering
- Verify cross-user isolation with same project hash
- Test unity_instances resource and set_active_instance user filtering

* feat: add comprehensive integration tests for API key authentication

- Add ApiKeyService tests covering validation, caching, retries, and singleton lifecycle
- Add startup config validation tests for remote-hosted mode requirements
- Test cache hit/miss scenarios, TTL expiration, and manual invalidation
- Test transient failure handling (5xx, timeouts, connection errors) with retry logic
- Test service token header injection and empty key fast-path validation
- Test startup validation requiring

* test: add autouse fixture to restore config state after startup validation tests

Ensures test isolation for config-dependent integration tests

* feat: skip user_id resolution in non-remote-hosted mode

Prevents unnecessary API key validation when not in remote-hosted mode

* test: add missing mock attributes to instance routing tests

- Add client_id to test context mock in set_active_instance test
- Add get_state mock to context in global instance routing test

* Fix broken telemetry test

* Add comprehensive API key authentication documentation

- Add user guide covering configuration, setup, and troubleshooting
- Add architecture reference documenting internal design and request flows

* Add remote-hosted mode and API key authentication documentation to server README

* Update reference doc for Docker Hub

* Specify exception being caught

* Ensure caplog handler cleanup in telemetry queue worker test

* Use NoUnitySessionError instead of RuntimeError in session isolation test

* Remove unusued monkeypatch arg

* Use more obviously fake API keys

* Reject connections when ApiKeyService is not initialized in remote-hosted mode

- Validate that user_id is present after successful key validation
- Expand transient error detection to include timeout and service errors
- Use consistent 1013 status code for retryable auth failures

* Accept "on" for UNITY_MCP_HTTP_REMOTE_HOSTED env var

Consistent with repo

* Invalidate cached login URL when HTTP base URL changes

* Pass API key as parameter instead of reading from EditorPrefs in RegisterWithCapturedValues

* Cache API key in field instead of reading from EditorPrefs on each reconnection

* Align markdown table formatting in remote server auth documentation

* Minor fixes

* security: Sanitize API key values in shell commands and fix minor issues

Add SanitizeShellHeaderValue() method to escape special shell characters (", \, `, $, !) in API keys before including them in shell command arguments. Apply sanitization to all three locations where API keys are embedded in shell commands (two in RegisterWithCapturedValues, one in GetManualInstructions).

Also fix deprecated passwordCharacter property (now maskChar) and improve exception logging in _resolve_user_id_from_request

* Consolidate duplicate instance selection error messages into InstanceSelectionRequiredError class

Add InstanceSelectionRequiredError exception class with centralized error messages (_SELECTION_REQUIRED and _MULTIPLE_INSTANCES). Replace 4 duplicate RuntimeError raises with new exception type. Update tests to catch InstanceSelectionRequiredError instead of RuntimeError.

* Replace hardcoded "X-API-Key" strings with AuthConstants.ApiKeyHeader constant across C# and Python codebases

Add AuthConstants class in C# and API_KEY_HEADER constant in Python to centralize the API key header name definition. Update all 8 locations where "X-API-Key" was hardcoded (4 in C#, 4 in Python) to use the new constants instead.

* Fix imports

* Filter session listing by user_id in all code paths to prevent cross-user session access

Remove conditional logic that only filtered sessions by user_id in remote-hosted mode. Now all session listings are filtered by user_id regardless of hosting mode, ensuring users can only see and interact with their own sessions.

* Consolidate get_session_id_by_hash methods into single method with optional user_id parameter

Merge get_session_id_by_hash and get_session_id_by_user_hash into a single method that accepts an optional user_id parameter. Update all call sites to use the unified method signature with user_id as the second parameter. Update tests and documentation to reflect the simplified API.

* Add environment variable support for project-scoped-tools flag [skip ci]

Support UNITY_MCP_PROJECT_SCOPED_TOOLS environment variable as alternative to --project-scoped-tools command line flag. Accept "true", "1", "yes", or "on" as truthy values (case-insensitive). Update help text to document the environment variable option.

* Fix Python tests

* Update validation logic to only require API key validation URL when both http_remote_hosted is enabled AND transport mode is "http", preventing false validation errors in stdio mode.

* Update Server/src/main.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Refactor HTTP transport configuration to support separate local and remote URLs

Split HTTP transport into HttpLocal and HttpRemote modes with separate EditorPrefs storage (HttpBaseUrl and HttpRemoteBaseUrl).

Add HttpEndpointUtility methods to get/save local and remote URLs independently, and introduce IsRemoteScope() and GetCurrentServerTransport() helpers to centralize 3-way transport determination (Stdio/Http/HttpRemote). Update all client configuration code to distinguish between local and remote HTTP

* Only include API key headers in HTTP/WebSocket configuration when in remote-hosted mode

Update all locations where API key headers are added to HTTP/WebSocket configurations to check HttpEndpointUtility.IsRemoteScope() or serverTransport == HttpRemote before including the API key. This prevents local HTTP mode from unnecessarily including API key headers in shell commands, config JSON, and WebSocket connections.

* Hide Manual Server Launch foldout when not in HTTP Local mode

* Fix failing test

* Improve error messaging and API key validation for HTTP Remote transport

Add detailed error messages to WebSocket connection failures that guide users to check server URL, server status, and API key validity. Store error state in TransportState for propagation to UI. Disable "Start Session" button when HTTP Remote mode is selected without an API key, with tooltip explaining requirement. Display error dialog on connection failure with specific error message from transport state. Update connection

* Add missing .meta file

* Store transport mode in ServerConfig instead of environment variable

* Add autouse fixture to restore global config state between tests

Add restore_global_config fixture in conftest.py that automatically saves and restores global config attributes and UNITY_MCP_TRANSPORT environment variable between tests. Update integration tests to use monkeypatch.setattr on config.transport_mode instead of monkeypatch.setenv to prevent test pollution and ensure clean state isolation.

* Fix startup

* Replace _current_transport() calls with direct config.transport_mode access

* Minor cleanup

* Add integration tests for HTTP transport authentication behavior

Verify that HTTP local mode allows requests without user_id while HTTP remote-hosted mode rejects them with auth_required error.

* Add smoke tests for transport routing paths across HTTP local, HTTP remote, and stdio modes

Verify that HTTP local routes through PluginHub without user_id, HTTP remote routes through PluginHub with user_id, and stdio calls legacy send function with instance_id. Each test uses monkeypatch to configure transport mode and mock appropriate transport layer functions.

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
main
Marcus Sanatan 2026-01-30 18:39:21 -04:00 committed by GitHub
parent 8ee9700327
commit 664a43b76c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
48 changed files with 3771 additions and 489 deletions

View File

@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 14a4b9a7f749248d496466c2a3a53e56
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -155,9 +155,19 @@ namespace MCPForUnity.Editor.Clients
client.configuredTransport = Models.ConfiguredTransport.Stdio; client.configuredTransport = Models.ConfiguredTransport.Stdio;
} }
else if (!string.IsNullOrEmpty(configuredUrl)) else if (!string.IsNullOrEmpty(configuredUrl))
{
// Distinguish HTTP Local from HTTP Remote by matching against both URLs
string localRpcUrl = HttpEndpointUtility.GetLocalMcpRpcUrl();
string remoteRpcUrl = HttpEndpointUtility.GetRemoteMcpRpcUrl();
if (!string.IsNullOrEmpty(remoteRpcUrl) && UrlsEqual(configuredUrl, remoteRpcUrl))
{
client.configuredTransport = Models.ConfiguredTransport.HttpRemote;
}
else
{ {
client.configuredTransport = Models.ConfiguredTransport.Http; client.configuredTransport = Models.ConfiguredTransport.Http;
} }
}
else else
{ {
client.configuredTransport = Models.ConfiguredTransport.Unknown; client.configuredTransport = Models.ConfiguredTransport.Unknown;
@ -173,6 +183,7 @@ namespace MCPForUnity.Editor.Clients
} }
else if (!string.IsNullOrEmpty(configuredUrl)) else if (!string.IsNullOrEmpty(configuredUrl))
{ {
// Match against the active scope's URL
string expectedUrl = HttpEndpointUtility.GetMcpRpcUrl(); string expectedUrl = HttpEndpointUtility.GetMcpRpcUrl();
matches = UrlsEqual(configuredUrl, expectedUrl); matches = UrlsEqual(configuredUrl, expectedUrl);
} }
@ -189,9 +200,7 @@ namespace MCPForUnity.Editor.Clients
if (result == "Configured successfully") if (result == "Configured successfully")
{ {
client.SetStatus(McpStatus.Configured); client.SetStatus(McpStatus.Configured);
// Update transport after rewrite based on current server setting client.configuredTransport = HttpEndpointUtility.GetCurrentServerTransport();
bool useHttp = EditorConfigurationCache.Instance.UseHttpTransport;
client.configuredTransport = useHttp ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio;
} }
else else
{ {
@ -220,9 +229,7 @@ namespace MCPForUnity.Editor.Clients
if (result == "Configured successfully") if (result == "Configured successfully")
{ {
client.SetStatus(McpStatus.Configured); client.SetStatus(McpStatus.Configured);
// Set transport based on current server setting client.configuredTransport = HttpEndpointUtility.GetCurrentServerTransport();
bool useHttp = EditorConfigurationCache.Instance.UseHttpTransport;
client.configuredTransport = useHttp ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio;
} }
else else
{ {
@ -271,9 +278,18 @@ namespace MCPForUnity.Editor.Clients
{ {
// Determine and set the configured transport type // Determine and set the configured transport type
if (!string.IsNullOrEmpty(url)) if (!string.IsNullOrEmpty(url))
{
// Distinguish HTTP Local from HTTP Remote
string remoteRpcUrl = HttpEndpointUtility.GetRemoteMcpRpcUrl();
if (!string.IsNullOrEmpty(remoteRpcUrl) && UrlsEqual(url, remoteRpcUrl))
{
client.configuredTransport = Models.ConfiguredTransport.HttpRemote;
}
else
{ {
client.configuredTransport = Models.ConfiguredTransport.Http; client.configuredTransport = Models.ConfiguredTransport.Http;
} }
}
else if (args != null && args.Length > 0) else if (args != null && args.Length > 0)
{ {
client.configuredTransport = Models.ConfiguredTransport.Stdio; client.configuredTransport = Models.ConfiguredTransport.Stdio;
@ -286,6 +302,7 @@ namespace MCPForUnity.Editor.Clients
bool matches = false; bool matches = false;
if (!string.IsNullOrEmpty(url)) if (!string.IsNullOrEmpty(url))
{ {
// Match against the active scope's URL
matches = UrlsEqual(url, HttpEndpointUtility.GetMcpRpcUrl()); matches = UrlsEqual(url, HttpEndpointUtility.GetMcpRpcUrl());
} }
else if (args != null && args.Length > 0) else if (args != null && args.Length > 0)
@ -313,9 +330,7 @@ namespace MCPForUnity.Editor.Clients
if (result == "Configured successfully") if (result == "Configured successfully")
{ {
client.SetStatus(McpStatus.Configured); client.SetStatus(McpStatus.Configured);
// Update transport after rewrite based on current server setting client.configuredTransport = HttpEndpointUtility.GetCurrentServerTransport();
bool useHttp = EditorConfigurationCache.Instance.UseHttpTransport;
client.configuredTransport = useHttp ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio;
} }
else else
{ {
@ -344,9 +359,7 @@ namespace MCPForUnity.Editor.Clients
if (result == "Configured successfully") if (result == "Configured successfully")
{ {
client.SetStatus(McpStatus.Configured); client.SetStatus(McpStatus.Configured);
// Set transport based on current server setting client.configuredTransport = HttpEndpointUtility.GetCurrentServerTransport();
bool useHttp = EditorConfigurationCache.Instance.UseHttpTransport;
client.configuredTransport = useHttp ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio;
} }
else else
{ {
@ -468,9 +481,13 @@ namespace MCPForUnity.Editor.Clients
bool registeredWithStdio = getStdout.Contains("Type: stdio", StringComparison.OrdinalIgnoreCase); bool registeredWithStdio = getStdout.Contains("Type: stdio", StringComparison.OrdinalIgnoreCase);
// Set the configured transport based on what we detected // Set the configured transport based on what we detected
// For HTTP, we can't distinguish local/remote from CLI output alone,
// so infer from the current scope setting when HTTP is detected.
if (registeredWithHttp) if (registeredWithHttp)
{ {
client.configuredTransport = Models.ConfiguredTransport.Http; client.configuredTransport = HttpEndpointUtility.IsRemoteScope()
? Models.ConfiguredTransport.HttpRemote
: Models.ConfiguredTransport.Http;
} }
else if (registeredWithStdio) else if (registeredWithStdio)
{ {
@ -481,7 +498,7 @@ namespace MCPForUnity.Editor.Clients
client.configuredTransport = Models.ConfiguredTransport.Unknown; client.configuredTransport = Models.ConfiguredTransport.Unknown;
} }
// Check for transport mismatch // Check for transport mismatch (3-way: Stdio, Http, HttpRemote)
bool hasTransportMismatch = (currentUseHttp && registeredWithStdio) || (!currentUseHttp && registeredWithHttp); bool hasTransportMismatch = (currentUseHttp && registeredWithStdio) || (!currentUseHttp && registeredWithHttp);
// For stdio transport, also check package version // For stdio transport, also check package version
@ -575,7 +592,9 @@ namespace MCPForUnity.Editor.Clients
public void ConfigureWithCapturedValues( public void ConfigureWithCapturedValues(
string projectDir, string claudePath, string pathPrepend, string projectDir, string claudePath, string pathPrepend,
bool useHttpTransport, string httpUrl, bool useHttpTransport, string httpUrl,
string uvxPath, string gitUrl, string packageName, bool shouldForceRefresh) string uvxPath, string gitUrl, string packageName, bool shouldForceRefresh,
string apiKey,
Models.ConfiguredTransport serverTransport)
{ {
if (client.status == McpStatus.Configured) if (client.status == McpStatus.Configured)
{ {
@ -584,7 +603,8 @@ namespace MCPForUnity.Editor.Clients
else else
{ {
RegisterWithCapturedValues(projectDir, claudePath, pathPrepend, RegisterWithCapturedValues(projectDir, claudePath, pathPrepend,
useHttpTransport, httpUrl, uvxPath, gitUrl, packageName, shouldForceRefresh); useHttpTransport, httpUrl, uvxPath, gitUrl, packageName, shouldForceRefresh,
apiKey, serverTransport);
} }
} }
@ -594,7 +614,9 @@ namespace MCPForUnity.Editor.Clients
private void RegisterWithCapturedValues( private void RegisterWithCapturedValues(
string projectDir, string claudePath, string pathPrepend, string projectDir, string claudePath, string pathPrepend,
bool useHttpTransport, string httpUrl, bool useHttpTransport, string httpUrl,
string uvxPath, string gitUrl, string packageName, bool shouldForceRefresh) string uvxPath, string gitUrl, string packageName, bool shouldForceRefresh,
string apiKey,
Models.ConfiguredTransport serverTransport)
{ {
if (string.IsNullOrEmpty(claudePath)) if (string.IsNullOrEmpty(claudePath))
{ {
@ -603,9 +625,18 @@ namespace MCPForUnity.Editor.Clients
string args; string args;
if (useHttpTransport) if (useHttpTransport)
{
// Only include API key header for remote-hosted mode
if (serverTransport == Models.ConfiguredTransport.HttpRemote && !string.IsNullOrEmpty(apiKey))
{
string safeKey = SanitizeShellHeaderValue(apiKey);
args = $"mcp add --transport http UnityMCP {httpUrl} --header \"{AuthConstants.ApiKeyHeader}: {safeKey}\"";
}
else
{ {
args = $"mcp add --transport http UnityMCP {httpUrl}"; args = $"mcp add --transport http UnityMCP {httpUrl}";
} }
}
else else
{ {
// Note: --reinstall is not supported by uvx, use --no-cache --refresh instead // Note: --reinstall is not supported by uvx, use --no-cache --refresh instead
@ -626,7 +657,7 @@ namespace MCPForUnity.Editor.Clients
McpLog.Info($"Successfully registered with Claude Code using {(useHttpTransport ? "HTTP" : "stdio")} transport."); McpLog.Info($"Successfully registered with Claude Code using {(useHttpTransport ? "HTTP" : "stdio")} transport.");
client.SetStatus(McpStatus.Configured); client.SetStatus(McpStatus.Configured);
client.configuredTransport = useHttpTransport ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio; client.configuredTransport = serverTransport;
} }
/// <summary> /// <summary>
@ -664,8 +695,25 @@ namespace MCPForUnity.Editor.Clients
if (useHttpTransport) if (useHttpTransport)
{ {
string httpUrl = HttpEndpointUtility.GetMcpRpcUrl(); string httpUrl = HttpEndpointUtility.GetMcpRpcUrl();
// Only include API key header for remote-hosted mode
if (HttpEndpointUtility.IsRemoteScope())
{
string apiKey = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
if (!string.IsNullOrEmpty(apiKey))
{
string safeKey = SanitizeShellHeaderValue(apiKey);
args = $"mcp add --transport http UnityMCP {httpUrl} --header \"{AuthConstants.ApiKeyHeader}: {safeKey}\"";
}
else
{
args = $"mcp add --transport http UnityMCP {httpUrl}"; args = $"mcp add --transport http UnityMCP {httpUrl}";
} }
}
else
{
args = $"mcp add --transport http UnityMCP {httpUrl}";
}
}
else else
{ {
var (uvxPath, gitUrl, packageName) = AssetPathUtility.GetUvxCommandParts(); var (uvxPath, gitUrl, packageName) = AssetPathUtility.GetUvxCommandParts();
@ -715,7 +763,7 @@ namespace MCPForUnity.Editor.Clients
// Set status to Configured immediately after successful registration // Set status to Configured immediately after successful registration
// The UI will trigger an async verification check separately to avoid blocking // The UI will trigger an async verification check separately to avoid blocking
client.SetStatus(McpStatus.Configured); client.SetStatus(McpStatus.Configured);
client.configuredTransport = useHttpTransport ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio; client.configuredTransport = HttpEndpointUtility.GetCurrentServerTransport();
} }
private void Unregister() private void Unregister()
@ -757,8 +805,15 @@ namespace MCPForUnity.Editor.Clients
if (useHttpTransport) if (useHttpTransport)
{ {
string httpUrl = HttpEndpointUtility.GetMcpRpcUrl(); string httpUrl = HttpEndpointUtility.GetMcpRpcUrl();
// Only include API key header for remote-hosted mode
string headerArg = "";
if (HttpEndpointUtility.IsRemoteScope())
{
string apiKey = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
headerArg = !string.IsNullOrEmpty(apiKey) ? $" --header \"{AuthConstants.ApiKeyHeader}: {SanitizeShellHeaderValue(apiKey)}\"" : "";
}
return "# Register the MCP server with Claude Code:\n" + return "# Register the MCP server with Claude Code:\n" +
$"claude mcp add --transport http UnityMCP {httpUrl}\n\n" + $"claude mcp add --transport http UnityMCP {httpUrl}{headerArg}\n\n" +
"# Unregister the MCP server:\n" + "# Unregister the MCP server:\n" +
"claude mcp remove UnityMCP\n\n" + "claude mcp remove UnityMCP\n\n" +
"# List registered servers:\n" + "# List registered servers:\n" +
@ -790,6 +845,37 @@ namespace MCPForUnity.Editor.Clients
"Restart Claude Code" "Restart Claude Code"
}; };
/// <summary>
/// Sanitizes a value for safe inclusion inside a double-quoted shell argument.
/// Escapes characters that are special within double quotes (", \, `, $, !)
/// to prevent shell injection or argument splitting.
/// </summary>
private static string SanitizeShellHeaderValue(string value)
{
if (string.IsNullOrEmpty(value))
return value;
var sb = new System.Text.StringBuilder(value.Length);
foreach (char c in value)
{
switch (c)
{
case '"':
case '\\':
case '`':
case '$':
case '!':
sb.Append('\\');
sb.Append(c);
break;
default:
sb.Append(c);
break;
}
}
return sb.ToString();
}
/// <summary> /// <summary>
/// Extracts the package source (--from argument value) from claude mcp get output. /// Extracts the package source (--from argument value) from claude mcp get output.
/// The output format includes args like: --from "mcpforunityserver==9.0.1" /// The output format includes args like: --from "mcpforunityserver==9.0.1"

View File

@ -0,0 +1,10 @@
namespace MCPForUnity.Editor.Constants
{
/// <summary>
/// Protocol-level constants for API key authentication.
/// </summary>
internal static class AuthConstants
{
internal const string ApiKeyHeader = "X-API-Key";
}
}

View File

@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 96844bc39e9a94cf18b18f8127f3854f
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -24,6 +24,7 @@ namespace MCPForUnity.Editor.Constants
internal const string ClaudeCliPathOverride = "MCPForUnity.ClaudeCliPath"; internal const string ClaudeCliPathOverride = "MCPForUnity.ClaudeCliPath";
internal const string HttpBaseUrl = "MCPForUnity.HttpUrl"; internal const string HttpBaseUrl = "MCPForUnity.HttpUrl";
internal const string HttpRemoteBaseUrl = "MCPForUnity.HttpRemoteUrl";
internal const string SessionId = "MCPForUnity.SessionId"; internal const string SessionId = "MCPForUnity.SessionId";
internal const string WebSocketUrlOverride = "MCPForUnity.WebSocketUrl"; internal const string WebSocketUrlOverride = "MCPForUnity.WebSocketUrl";
internal const string GitUrlOverride = "MCPForUnity.GitUrlOverride"; internal const string GitUrlOverride = "MCPForUnity.GitUrlOverride";
@ -55,5 +56,7 @@ namespace MCPForUnity.Editor.Constants
internal const string TelemetryDisabled = "MCPForUnity.TelemetryDisabled"; internal const string TelemetryDisabled = "MCPForUnity.TelemetryDisabled";
internal const string CustomerUuid = "MCPForUnity.CustomerUUID"; internal const string CustomerUuid = "MCPForUnity.CustomerUUID";
internal const string ApiKey = "MCPForUnity.ApiKey";
} }
} }

View File

@ -71,6 +71,26 @@ namespace MCPForUnity.Editor.Helpers
if (unity["command"] != null) unity.Remove("command"); if (unity["command"] != null) unity.Remove("command");
if (unity["args"] != null) unity.Remove("args"); if (unity["args"] != null) unity.Remove("args");
// Only include API key header for remote-hosted mode
if (HttpEndpointUtility.IsRemoteScope())
{
string apiKey = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
if (!string.IsNullOrEmpty(apiKey))
{
var headers = new JObject { [AuthConstants.ApiKeyHeader] = apiKey };
unity["headers"] = headers;
}
else
{
if (unity["headers"] != null) unity.Remove("headers");
}
}
else
{
// Local HTTP doesn't use API keys; remove any stale headers
if (unity["headers"] != null) unity.Remove("headers");
}
if (isVSCode) if (isVSCode)
{ {
unity["type"] = "http"; unity["type"] = "http";

View File

@ -1,5 +1,7 @@
using System; using System;
using MCPForUnity.Editor.Constants; using MCPForUnity.Editor.Constants;
using MCPForUnity.Editor.Models;
using MCPForUnity.Editor.Services;
using UnityEditor; using UnityEditor;
namespace MCPForUnity.Editor.Helpers namespace MCPForUnity.Editor.Helpers
@ -8,38 +10,113 @@ namespace MCPForUnity.Editor.Helpers
/// Helper methods for managing HTTP endpoint URLs used by the MCP bridge. /// Helper methods for managing HTTP endpoint URLs used by the MCP bridge.
/// Ensures the stored value is always the base URL (without trailing path), /// Ensures the stored value is always the base URL (without trailing path),
/// and provides convenience accessors for specific endpoints. /// and provides convenience accessors for specific endpoints.
///
/// HTTP Local and HTTP Remote use separate EditorPrefs keys so that switching
/// between scopes does not overwrite the other scope's URL.
/// </summary> /// </summary>
public static class HttpEndpointUtility public static class HttpEndpointUtility
{ {
private const string PrefKey = EditorPrefKeys.HttpBaseUrl; private const string LocalPrefKey = EditorPrefKeys.HttpBaseUrl;
private const string DefaultBaseUrl = "http://localhost:8080"; private const string RemotePrefKey = EditorPrefKeys.HttpRemoteBaseUrl;
private const string DefaultLocalBaseUrl = "http://localhost:8080";
private const string DefaultRemoteBaseUrl = "";
/// <summary> /// <summary>
/// Returns the normalized base URL currently stored in EditorPrefs. /// Returns the normalized base URL for the currently active HTTP scope.
/// If the scope is "remote", returns the remote URL; otherwise returns the local URL.
/// </summary> /// </summary>
public static string GetBaseUrl() public static string GetBaseUrl()
{ {
string stored = EditorPrefs.GetString(PrefKey, DefaultBaseUrl); return IsRemoteScope() ? GetRemoteBaseUrl() : GetLocalBaseUrl();
return NormalizeBaseUrl(stored);
} }
/// <summary> /// <summary>
/// Saves a user-provided URL after normalizing it to a base form. /// Saves a user-provided URL to the currently active HTTP scope's pref.
/// </summary> /// </summary>
public static void SaveBaseUrl(string userValue) public static void SaveBaseUrl(string userValue)
{ {
string normalized = NormalizeBaseUrl(userValue); if (IsRemoteScope())
EditorPrefs.SetString(PrefKey, normalized); {
SaveRemoteBaseUrl(userValue);
}
else
{
SaveLocalBaseUrl(userValue);
}
} }
/// <summary> /// <summary>
/// Builds the JSON-RPC endpoint used by FastMCP clients (base + /mcp). /// Returns the normalized local HTTP base URL (always reads local pref).
/// </summary>
public static string GetLocalBaseUrl()
{
string stored = EditorPrefs.GetString(LocalPrefKey, DefaultLocalBaseUrl);
return NormalizeBaseUrl(stored, DefaultLocalBaseUrl);
}
/// <summary>
/// Saves a user-provided URL to the local HTTP pref.
/// </summary>
public static void SaveLocalBaseUrl(string userValue)
{
string normalized = NormalizeBaseUrl(userValue, DefaultLocalBaseUrl);
EditorPrefs.SetString(LocalPrefKey, normalized);
}
/// <summary>
/// Returns the normalized remote HTTP base URL (always reads remote pref).
/// Returns empty string if no remote URL is configured.
/// </summary>
public static string GetRemoteBaseUrl()
{
string stored = EditorPrefs.GetString(RemotePrefKey, DefaultRemoteBaseUrl);
if (string.IsNullOrWhiteSpace(stored))
{
return DefaultRemoteBaseUrl;
}
return NormalizeBaseUrl(stored, DefaultRemoteBaseUrl);
}
/// <summary>
/// Saves a user-provided URL to the remote HTTP pref.
/// </summary>
public static void SaveRemoteBaseUrl(string userValue)
{
if (string.IsNullOrWhiteSpace(userValue))
{
EditorPrefs.SetString(RemotePrefKey, DefaultRemoteBaseUrl);
return;
}
string normalized = NormalizeBaseUrl(userValue, DefaultRemoteBaseUrl);
EditorPrefs.SetString(RemotePrefKey, normalized);
}
/// <summary>
/// Builds the JSON-RPC endpoint for the currently active scope (base + /mcp).
/// </summary> /// </summary>
public static string GetMcpRpcUrl() public static string GetMcpRpcUrl()
{ {
return AppendPathSegment(GetBaseUrl(), "mcp"); return AppendPathSegment(GetBaseUrl(), "mcp");
} }
/// <summary>
/// Builds the local JSON-RPC endpoint (local base + /mcp).
/// </summary>
public static string GetLocalMcpRpcUrl()
{
return AppendPathSegment(GetLocalBaseUrl(), "mcp");
}
/// <summary>
/// Builds the remote JSON-RPC endpoint (remote base + /mcp).
/// Returns empty string if no remote URL is configured.
/// </summary>
public static string GetRemoteMcpRpcUrl()
{
string remoteBase = GetRemoteBaseUrl();
return string.IsNullOrEmpty(remoteBase) ? string.Empty : AppendPathSegment(remoteBase, "mcp");
}
/// <summary> /// <summary>
/// Builds the endpoint used when POSTing custom-tool registration payloads. /// Builds the endpoint used when POSTing custom-tool registration payloads.
/// </summary> /// </summary>
@ -48,14 +125,35 @@ namespace MCPForUnity.Editor.Helpers
return AppendPathSegment(GetBaseUrl(), "register-tools"); return AppendPathSegment(GetBaseUrl(), "register-tools");
} }
/// <summary>
/// Returns true if the active HTTP transport scope is "remote".
/// </summary>
public static bool IsRemoteScope()
{
string scope = EditorConfigurationCache.Instance.HttpTransportScope;
return string.Equals(scope, "remote", StringComparison.OrdinalIgnoreCase);
}
/// <summary>
/// Returns the <see cref="ConfiguredTransport"/> that matches the current server-side
/// transport selection (Stdio, Http, or HttpRemote).
/// Centralises the 3-way determination so callers avoid duplicated logic.
/// </summary>
public static ConfiguredTransport GetCurrentServerTransport()
{
bool useHttp = EditorConfigurationCache.Instance.UseHttpTransport;
if (!useHttp) return ConfiguredTransport.Stdio;
return IsRemoteScope() ? ConfiguredTransport.HttpRemote : ConfiguredTransport.Http;
}
/// <summary> /// <summary>
/// Normalizes a URL so that we consistently store just the base (no trailing slash/path). /// Normalizes a URL so that we consistently store just the base (no trailing slash/path).
/// </summary> /// </summary>
private static string NormalizeBaseUrl(string value) private static string NormalizeBaseUrl(string value, string defaultUrl)
{ {
if (string.IsNullOrWhiteSpace(value)) if (string.IsNullOrWhiteSpace(value))
{ {
return DefaultBaseUrl; return defaultUrl;
} }
string trimmed = value.Trim(); string trimmed = value.Trim();

View File

@ -23,7 +23,8 @@ namespace MCPForUnity.Editor.Models
{ {
Unknown, // Could not determine transport type Unknown, // Could not determine transport type
Stdio, // Client configured for stdio transport Stdio, // Client configured for stdio transport
Http // Client configured for HTTP transport Http, // Client configured for HTTP local transport
HttpRemote // Client configured for HTTP remote-hosted transport
} }
} }

View File

@ -53,6 +53,7 @@ namespace MCPForUnity.Editor.Services
private string _uvxPathOverride; private string _uvxPathOverride;
private string _gitUrlOverride; private string _gitUrlOverride;
private string _httpBaseUrl; private string _httpBaseUrl;
private string _httpRemoteBaseUrl;
private string _claudeCliPathOverride; private string _claudeCliPathOverride;
private string _httpTransportScope; private string _httpTransportScope;
private int _unitySocketPort; private int _unitySocketPort;
@ -94,11 +95,17 @@ namespace MCPForUnity.Editor.Services
public string GitUrlOverride => _gitUrlOverride; public string GitUrlOverride => _gitUrlOverride;
/// <summary> /// <summary>
/// HTTP base URL for the MCP server. /// HTTP base URL for the local MCP server.
/// Default: empty string /// Default: empty string
/// </summary> /// </summary>
public string HttpBaseUrl => _httpBaseUrl; public string HttpBaseUrl => _httpBaseUrl;
/// <summary>
/// HTTP base URL for the remote-hosted MCP server.
/// Default: empty string
/// </summary>
public string HttpRemoteBaseUrl => _httpRemoteBaseUrl;
/// <summary> /// <summary>
/// Custom path override for Claude CLI executable. /// Custom path override for Claude CLI executable.
/// Default: empty string (auto-detect) /// Default: empty string (auto-detect)
@ -135,6 +142,7 @@ namespace MCPForUnity.Editor.Services
_uvxPathOverride = EditorPrefs.GetString(EditorPrefKeys.UvxPathOverride, string.Empty); _uvxPathOverride = EditorPrefs.GetString(EditorPrefKeys.UvxPathOverride, string.Empty);
_gitUrlOverride = EditorPrefs.GetString(EditorPrefKeys.GitUrlOverride, string.Empty); _gitUrlOverride = EditorPrefs.GetString(EditorPrefKeys.GitUrlOverride, string.Empty);
_httpBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpBaseUrl, string.Empty); _httpBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpBaseUrl, string.Empty);
_httpRemoteBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpRemoteBaseUrl, string.Empty);
_claudeCliPathOverride = EditorPrefs.GetString(EditorPrefKeys.ClaudeCliPathOverride, string.Empty); _claudeCliPathOverride = EditorPrefs.GetString(EditorPrefKeys.ClaudeCliPathOverride, string.Empty);
_httpTransportScope = EditorPrefs.GetString(EditorPrefKeys.HttpTransportScope, string.Empty); _httpTransportScope = EditorPrefs.GetString(EditorPrefKeys.HttpTransportScope, string.Empty);
_unitySocketPort = EditorPrefs.GetInt(EditorPrefKeys.UnitySocketPort, 0); _unitySocketPort = EditorPrefs.GetInt(EditorPrefKeys.UnitySocketPort, 0);
@ -234,6 +242,20 @@ namespace MCPForUnity.Editor.Services
} }
} }
/// <summary>
/// Set HttpRemoteBaseUrl and update cache + EditorPrefs atomically.
/// </summary>
public void SetHttpRemoteBaseUrl(string value)
{
value = value ?? string.Empty;
if (_httpRemoteBaseUrl != value)
{
_httpRemoteBaseUrl = value;
EditorPrefs.SetString(EditorPrefKeys.HttpRemoteBaseUrl, value);
OnConfigurationChanged?.Invoke(nameof(HttpRemoteBaseUrl));
}
}
/// <summary> /// <summary>
/// Set ClaudeCliPathOverride and update cache + EditorPrefs atomically. /// Set ClaudeCliPathOverride and update cache + EditorPrefs atomically.
/// </summary> /// </summary>
@ -304,6 +326,9 @@ namespace MCPForUnity.Editor.Services
case nameof(HttpBaseUrl): case nameof(HttpBaseUrl):
_httpBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpBaseUrl, string.Empty); _httpBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpBaseUrl, string.Empty);
break; break;
case nameof(HttpRemoteBaseUrl):
_httpRemoteBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpRemoteBaseUrl, string.Empty);
break;
case nameof(ClaudeCliPathOverride): case nameof(ClaudeCliPathOverride):
_claudeCliPathOverride = EditorPrefs.GetString(EditorPrefKeys.ClaudeCliPathOverride, string.Empty); _claudeCliPathOverride = EditorPrefs.GetString(EditorPrefKeys.ClaudeCliPathOverride, string.Empty);
break; break;

View File

@ -30,7 +30,7 @@ namespace MCPForUnity.Editor.Services.Server
return false; return false;
} }
string httpUrl = HttpEndpointUtility.GetBaseUrl(); string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
if (!IsLocalUrl(httpUrl)) if (!IsLocalUrl(httpUrl))
{ {
error = $"The configured URL ({httpUrl}) is not a local address. Local server launch only works for localhost."; error = $"The configured URL ({httpUrl}) is not a local address. Local server launch only works for localhost.";

View File

@ -1,7 +1,7 @@
using System; using System;
using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Collections.Generic;
using System.Net.Sockets; using System.Net.Sockets;
using MCPForUnity.Editor.Constants; using MCPForUnity.Editor.Constants;
using MCPForUnity.Editor.Helpers; using MCPForUnity.Editor.Helpers;
@ -158,7 +158,7 @@ namespace MCPForUnity.Editor.Services
if (success) if (success)
{ {
McpLog.Debug($"uv cache cleared successfully: {stdout}"); McpLog.Info($"uv cache cleared successfully: {stdout}");
return true; return true;
} }
string combinedOutput = string.Join( string combinedOutput = string.Join(
@ -253,7 +253,7 @@ namespace MCPForUnity.Editor.Services
// If the port is still occupied, don't start and explain why (avoid confusing "refusing to stop" warnings). // If the port is still occupied, don't start and explain why (avoid confusing "refusing to stop" warnings).
try try
{ {
string httpUrl = HttpEndpointUtility.GetBaseUrl(); string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
if (Uri.TryCreate(httpUrl, UriKind.Absolute, out var uri) && uri.Port > 0) if (Uri.TryCreate(httpUrl, UriKind.Absolute, out var uri) && uri.Port > 0)
{ {
var remaining = GetListeningProcessIdsForPort(uri.Port); var remaining = GetListeningProcessIdsForPort(uri.Port);
@ -274,7 +274,7 @@ namespace MCPForUnity.Editor.Services
// Note: Dev mode cache-busting is handled by `uvx --no-cache --refresh` in the generated command. // Note: Dev mode cache-busting is handled by `uvx --no-cache --refresh` in the generated command.
// Create a per-launch token + pidfile path so Stop can be deterministic without relying on port/PID heuristics. // Create a per-launch token + pidfile path so Stop can be deterministic without relying on port/PID heuristics.
string baseUrlForPid = HttpEndpointUtility.GetBaseUrl(); string baseUrlForPid = HttpEndpointUtility.GetLocalBaseUrl();
Uri.TryCreate(baseUrlForPid, UriKind.Absolute, out var uriForPid); Uri.TryCreate(baseUrlForPid, UriKind.Absolute, out var uriForPid);
int portForPid = uriForPid?.Port ?? 0; int portForPid = uriForPid?.Port ?? 0;
string instanceToken = Guid.NewGuid().ToString("N"); string instanceToken = Guid.NewGuid().ToString("N");
@ -350,7 +350,7 @@ namespace MCPForUnity.Editor.Services
int port = 0; int port = 0;
if (!TryGetPortFromPidFilePath(pidFilePath, out port) || port <= 0) if (!TryGetPortFromPidFilePath(pidFilePath, out port) || port <= 0)
{ {
string baseUrl = HttpEndpointUtility.GetBaseUrl(); string baseUrl = HttpEndpointUtility.GetLocalBaseUrl();
if (IsLocalUrl(baseUrl) if (IsLocalUrl(baseUrl)
&& Uri.TryCreate(baseUrl, UriKind.Absolute, out var uri) && Uri.TryCreate(baseUrl, UriKind.Absolute, out var uri)
&& uri.Port > 0) && uri.Port > 0)
@ -371,7 +371,7 @@ namespace MCPForUnity.Editor.Services
{ {
try try
{ {
string httpUrl = HttpEndpointUtility.GetBaseUrl(); string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
if (!IsLocalUrl(httpUrl)) if (!IsLocalUrl(httpUrl))
{ {
return false; return false;
@ -433,7 +433,7 @@ namespace MCPForUnity.Editor.Services
{ {
try try
{ {
string httpUrl = HttpEndpointUtility.GetBaseUrl(); string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
if (!IsLocalUrl(httpUrl)) if (!IsLocalUrl(httpUrl))
{ {
return false; return false;
@ -500,7 +500,7 @@ namespace MCPForUnity.Editor.Services
private bool StopLocalHttpServerInternal(bool quiet, int? portOverride = null, bool allowNonLocalUrl = false) private bool StopLocalHttpServerInternal(bool quiet, int? portOverride = null, bool allowNonLocalUrl = false)
{ {
string httpUrl = HttpEndpointUtility.GetBaseUrl(); string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
if (!allowNonLocalUrl && !IsLocalUrl(httpUrl)) if (!allowNonLocalUrl && !IsLocalUrl(httpUrl))
{ {
if (!quiet) if (!quiet)
@ -836,7 +836,7 @@ namespace MCPForUnity.Editor.Services
/// </summary> /// </summary>
public bool IsLocalUrl() public bool IsLocalUrl()
{ {
string httpUrl = HttpEndpointUtility.GetBaseUrl(); string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
return IsLocalUrl(httpUrl); return IsLocalUrl(httpUrl);
} }

View File

@ -67,7 +67,7 @@ namespace MCPForUnity.Editor.Services.Transport
{ {
McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}"); McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}");
} }
UpdateState(mode, TransportState.Disconnected(client.TransportName, "Failed to start")); UpdateState(mode, TransportState.Disconnected(client.TransportName, client.State?.Error ?? "Failed to start"));
return false; return false;
} }

View File

@ -6,6 +6,7 @@ using System.Net.WebSockets;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MCPForUnity.Editor.Constants;
using MCPForUnity.Editor.Helpers; using MCPForUnity.Editor.Helpers;
using MCPForUnity.Editor.Services; using MCPForUnity.Editor.Services;
using MCPForUnity.Editor.Services.Transport; using MCPForUnity.Editor.Services.Transport;
@ -56,6 +57,7 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
private volatile bool _isConnected; private volatile bool _isConnected;
private int _isReconnectingFlag; private int _isReconnectingFlag;
private TransportState _state = TransportState.Disconnected(TransportDisplayName, "Transport not started"); private TransportState _state = TransportState.Disconnected(TransportDisplayName, "Transport not started");
private string _apiKey;
private bool _disposed; private bool _disposed;
public WebSocketTransportClient(IToolDiscoveryService toolDiscoveryService = null) public WebSocketTransportClient(IToolDiscoveryService toolDiscoveryService = null)
@ -80,6 +82,9 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
_projectName = ProjectIdentityUtility.GetProjectName(); _projectName = ProjectIdentityUtility.GetProjectName();
_projectHash = ProjectIdentityUtility.GetProjectHash(); _projectHash = ProjectIdentityUtility.GetProjectHash();
_unityVersion = Application.unityVersion; _unityVersion = Application.unityVersion;
_apiKey = HttpEndpointUtility.IsRemoteScope()
? EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty)
: string.Empty;
// Get project root path (strip /Assets from dataPath) for focus nudging // Get project root path (strip /Assets from dataPath) for focus nudging
string dataPath = Application.dataPath; string dataPath = Application.dataPath;
@ -214,13 +219,21 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
_socket = new ClientWebSocket(); _socket = new ClientWebSocket();
_socket.Options.KeepAliveInterval = _socketKeepAliveInterval; _socket.Options.KeepAliveInterval = _socketKeepAliveInterval;
// Add API key header if configured (for remote-hosted mode)
if (!string.IsNullOrEmpty(_apiKey))
{
_socket.Options.SetRequestHeader(AuthConstants.ApiKeyHeader, _apiKey);
}
try try
{ {
await _socket.ConnectAsync(_endpointUri, connectionToken).ConfigureAwait(false); await _socket.ConnectAsync(_endpointUri, connectionToken).ConfigureAwait(false);
} }
catch (Exception ex) catch (Exception ex)
{ {
McpLog.Error($"[WebSocket] Connection failed: {ex.Message}"); string errorMsg = "Connection failed. Check that the server URL is correct, the server is running, and your API key (if required) is valid.";
McpLog.Error($"[WebSocket] {errorMsg} (Detail: {ex.Message})");
_state = TransportState.Disconnected(TransportDisplayName, errorMsg);
return false; return false;
} }
@ -232,7 +245,9 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
} }
catch (Exception ex) catch (Exception ex)
{ {
McpLog.Error($"[WebSocket] Registration failed: {ex.Message}"); string regMsg = $"Registration with server failed: {ex.Message}";
McpLog.Error($"[WebSocket] {regMsg}");
_state = TransportState.Disconnected(TransportDisplayName, regMsg);
return false; return false;
} }

View File

@ -275,6 +275,7 @@ namespace MCPForUnity.Editor.Windows.Components.ClientConfig
string httpUrl = HttpEndpointUtility.GetMcpRpcUrl(); string httpUrl = HttpEndpointUtility.GetMcpRpcUrl();
var (uvxPath, gitUrl, packageName) = AssetPathUtility.GetUvxCommandParts(); var (uvxPath, gitUrl, packageName) = AssetPathUtility.GetUvxCommandParts();
bool shouldForceRefresh = AssetPathUtility.ShouldForceUvxRefresh(); bool shouldForceRefresh = AssetPathUtility.ShouldForceUvxRefresh();
string apiKey = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
// Compute pathPrepend on main thread // Compute pathPrepend on main thread
string pathPrepend = null; string pathPrepend = null;
@ -296,10 +297,12 @@ namespace MCPForUnity.Editor.Windows.Components.ClientConfig
{ {
if (client is ClaudeCliMcpConfigurator cliConfigurator) if (client is ClaudeCliMcpConfigurator cliConfigurator)
{ {
var serverTransport = HttpEndpointUtility.GetCurrentServerTransport();
cliConfigurator.ConfigureWithCapturedValues( cliConfigurator.ConfigureWithCapturedValues(
projectDir, claudePath, pathPrepend, projectDir, claudePath, pathPrepend,
useHttpTransport, httpUrl, useHttpTransport, httpUrl,
uvxPath, gitUrl, packageName, shouldForceRefresh); uvxPath, gitUrl, packageName, shouldForceRefresh,
apiKey, serverTransport);
} }
return (success: true, error: (string)null); return (success: true, error: (string)null);
} }
@ -525,12 +528,11 @@ namespace MCPForUnity.Editor.Windows.Components.ClientConfig
return; return;
} }
// Check for transport mismatch // Check for transport mismatch (3-way: Stdio, Http, HttpRemote)
bool hasTransportMismatch = false; bool hasTransportMismatch = false;
if (client.ConfiguredTransport != ConfiguredTransport.Unknown) if (client.ConfiguredTransport != ConfiguredTransport.Unknown)
{ {
bool serverUsesHttp = EditorConfigurationCache.Instance.UseHttpTransport; ConfiguredTransport serverTransport = HttpEndpointUtility.GetCurrentServerTransport();
ConfiguredTransport serverTransport = serverUsesHttp ? ConfiguredTransport.Http : ConfiguredTransport.Stdio;
hasTransportMismatch = client.ConfiguredTransport != serverTransport; hasTransportMismatch = client.ConfiguredTransport != serverTransport;
} }

View File

@ -45,6 +45,13 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
private Label connectionStatusLabel; private Label connectionStatusLabel;
private Button connectionToggleButton; private Button connectionToggleButton;
// API Key UI Elements (for remote-hosted mode)
private VisualElement apiKeyRow;
private TextField apiKeyField;
private Button getApiKeyButton;
private Button clearApiKeyButton;
private string cachedLoginUrl;
private bool connectionToggleInProgress; private bool connectionToggleInProgress;
private bool httpServerToggleInProgress; private bool httpServerToggleInProgress;
private Task verificationTask; private Task verificationTask;
@ -93,6 +100,12 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
statusIndicator = Root.Q<VisualElement>("status-indicator"); statusIndicator = Root.Q<VisualElement>("status-indicator");
connectionStatusLabel = Root.Q<Label>("connection-status"); connectionStatusLabel = Root.Q<Label>("connection-status");
connectionToggleButton = Root.Q<Button>("connection-toggle"); connectionToggleButton = Root.Q<Button>("connection-toggle");
// API Key UI Elements
apiKeyRow = Root.Q<VisualElement>("api-key-row");
apiKeyField = Root.Q<TextField>("api-key-field");
getApiKeyButton = Root.Q<Button>("get-api-key-button");
clearApiKeyButton = Root.Q<Button>("clear-api-key-button");
} }
private void InitializeUI() private void InitializeUI()
@ -139,6 +152,15 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
httpUrlField.value = HttpEndpointUtility.GetBaseUrl(); httpUrlField.value = HttpEndpointUtility.GetBaseUrl();
// Initialize API key field
if (apiKeyField != null)
{
apiKeyField.value = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
apiKeyField.tooltip = "API key for remote-hosted MCP server authentication";
apiKeyField.isPasswordField = true;
apiKeyField.maskChar = '*';
}
int unityPort = EditorPrefs.GetInt(EditorPrefKeys.UnitySocketPort, 0); int unityPort = EditorPrefs.GetInt(EditorPrefKeys.UnitySocketPort, 0);
if (unityPort == 0) if (unityPort == 0)
{ {
@ -170,6 +192,8 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
EditorConfigurationCache.Instance.SetHttpTransportScope(scope); EditorConfigurationCache.Instance.SetHttpTransportScope(scope);
} }
// Swap the displayed URL to match the newly selected scope
SyncUrlFieldToScope();
UpdateHttpFieldVisibility(); UpdateHttpFieldVisibility();
RefreshHttpUi(); RefreshHttpUi();
UpdateConnectionStatus(); UpdateConnectionStatus();
@ -247,6 +271,30 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
}); });
connectionToggleButton.clicked += OnConnectionToggleClicked; connectionToggleButton.clicked += OnConnectionToggleClicked;
// API Key field callbacks
if (apiKeyField != null)
{
apiKeyField.RegisterCallback<FocusOutEvent>(_ => PersistApiKeyFromField());
apiKeyField.RegisterCallback<KeyDownEvent>(evt =>
{
if (evt.keyCode == KeyCode.Return || evt.keyCode == KeyCode.KeypadEnter)
{
PersistApiKeyFromField();
evt.StopPropagation();
}
});
}
if (getApiKeyButton != null)
{
getApiKeyButton.clicked += OnGetApiKeyClicked;
}
if (clearApiKeyButton != null)
{
clearApiKeyButton.clicked += OnClearApiKeyClicked;
}
} }
private void PersistHttpUrlFromField() private void PersistHttpUrlFromField()
@ -259,6 +307,8 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
HttpEndpointUtility.SaveBaseUrl(httpUrlField.text); HttpEndpointUtility.SaveBaseUrl(httpUrlField.text);
// Update displayed value to normalized form without re-triggering callbacks/caret jumps. // Update displayed value to normalized form without re-triggering callbacks/caret jumps.
httpUrlField.SetValueWithoutNotify(HttpEndpointUtility.GetBaseUrl()); httpUrlField.SetValueWithoutNotify(HttpEndpointUtility.GetBaseUrl());
// Invalidate cached login URL so it is re-fetched for the new base URL.
cachedLoginUrl = null;
OnManualConfigUpdateRequested?.Invoke(); OnManualConfigUpdateRequested?.Invoke();
RefreshHttpUi(); RefreshHttpUi();
} }
@ -338,7 +388,15 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
statusIndicator.RemoveFromClassList("connected"); statusIndicator.RemoveFromClassList("connected");
statusIndicator.AddToClassList("disconnected"); statusIndicator.AddToClassList("disconnected");
connectionToggleButton.text = "Start Session"; connectionToggleButton.text = "Start Session";
connectionToggleButton.SetEnabled(true);
// Disable Start Session for HTTP Remote when no API key is set
bool httpRemoteNeedsKey = transportDropdown != null
&& (TransportProtocol)transportDropdown.value == TransportProtocol.HTTPRemote
&& string.IsNullOrEmpty(EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty));
connectionToggleButton.SetEnabled(!httpRemoteNeedsKey);
connectionToggleButton.tooltip = httpRemoteNeedsKey
? "An API key is required for HTTP Remote. Enter one above."
: string.Empty;
} }
unityPortField.SetEnabled(!isStdioResuming); unityPortField.SetEnabled(!isStdioResuming);
@ -439,10 +497,19 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
{ {
bool useHttp = (TransportProtocol)transportDropdown.value != TransportProtocol.Stdio; bool useHttp = (TransportProtocol)transportDropdown.value != TransportProtocol.Stdio;
bool httpLocalSelected = IsHttpLocalSelected(); bool httpLocalSelected = IsHttpLocalSelected();
bool httpRemoteSelected = transportDropdown != null && (TransportProtocol)transportDropdown.value == TransportProtocol.HTTPRemote;
httpUrlRow.style.display = useHttp ? DisplayStyle.Flex : DisplayStyle.None; httpUrlRow.style.display = useHttp ? DisplayStyle.Flex : DisplayStyle.None;
httpServerControlRow.style.display = useHttp && httpLocalSelected ? DisplayStyle.Flex : DisplayStyle.None; httpServerControlRow.style.display = useHttp && httpLocalSelected ? DisplayStyle.Flex : DisplayStyle.None;
unitySocketPortRow.style.display = useHttp ? DisplayStyle.None : DisplayStyle.Flex; unitySocketPortRow.style.display = useHttp ? DisplayStyle.None : DisplayStyle.Flex;
// Manual Server Launch foldout only relevant for HTTP Local
if (manualCommandFoldout != null)
manualCommandFoldout.style.display = httpLocalSelected ? DisplayStyle.Flex : DisplayStyle.None;
// API key fields only visible in HTTP Remote mode
if (apiKeyRow != null)
apiKeyRow.style.display = httpRemoteSelected ? DisplayStyle.Flex : DisplayStyle.None;
} }
private bool IsHttpLocalSelected() private bool IsHttpLocalSelected()
@ -450,6 +517,13 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
return transportDropdown != null && (TransportProtocol)transportDropdown.value == TransportProtocol.HTTPLocal; return transportDropdown != null && (TransportProtocol)transportDropdown.value == TransportProtocol.HTTPLocal;
} }
private void SyncUrlFieldToScope()
{
if (httpUrlField == null) return;
httpUrlField.SetValueWithoutNotify(HttpEndpointUtility.GetBaseUrl());
cachedLoginUrl = null;
}
private void UpdateStartHttpButtonState() private void UpdateStartHttpButtonState()
{ {
if (startHttpServerButton == null) if (startHttpServerButton == null)
@ -674,7 +748,13 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
} }
else else
{ {
McpLog.Warn("Failed to start MCP bridge"); var mode = EditorConfigurationCache.Instance.UseHttpTransport
? TransportMode.Http : TransportMode.Stdio;
var state = MCPServiceLocator.TransportManager.GetState(mode);
string errorMsg = state?.Error
?? "Failed to start the MCP session. Check the server URL and that the server is running.";
EditorUtility.DisplayDialog("Connection Failed", errorMsg, "OK");
McpLog.Warn($"Failed to start MCP bridge: {errorMsg}");
} }
} }
} }
@ -720,6 +800,110 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
} }
} }
private void PersistApiKeyFromField()
{
if (apiKeyField == null)
{
return;
}
string apiKey = apiKeyField.text?.Trim() ?? string.Empty;
string existingKey = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
if (apiKey != existingKey)
{
EditorPrefs.SetString(EditorPrefKeys.ApiKey, apiKey);
OnManualConfigUpdateRequested?.Invoke();
UpdateConnectionStatus();
McpLog.Info(string.IsNullOrEmpty(apiKey) ? "API key cleared" : "API key updated");
}
}
private async void OnGetApiKeyClicked()
{
if (getApiKeyButton != null)
{
getApiKeyButton.SetEnabled(false);
}
try
{
string loginUrl = await GetLoginUrlAsync();
if (string.IsNullOrEmpty(loginUrl))
{
EditorUtility.DisplayDialog("API Key",
"API key management is not available for this server. Contact your server administrator.",
"OK");
return;
}
Application.OpenURL(loginUrl);
}
catch (Exception ex)
{
McpLog.Error($"Failed to get login URL: {ex.Message}");
EditorUtility.DisplayDialog("Error",
$"Failed to get API key login URL:\n\n{ex.Message}",
"OK");
}
finally
{
if (getApiKeyButton != null)
{
getApiKeyButton.SetEnabled(true);
}
}
}
private async Task<string> GetLoginUrlAsync()
{
if (!string.IsNullOrEmpty(cachedLoginUrl))
{
return cachedLoginUrl;
}
string baseUrl = HttpEndpointUtility.GetBaseUrl();
string loginUrlEndpoint = $"{baseUrl.TrimEnd('/')}/api/auth/login-url";
try
{
using (var client = new System.Net.Http.HttpClient())
{
client.Timeout = TimeSpan.FromSeconds(10);
var response = await client.GetAsync(loginUrlEndpoint);
if (response.IsSuccessStatusCode)
{
string json = await response.Content.ReadAsStringAsync();
var result = Newtonsoft.Json.Linq.JObject.Parse(json);
if (result.Value<bool>("success"))
{
cachedLoginUrl = result.Value<string>("login_url");
return cachedLoginUrl;
}
}
}
}
catch (Exception ex)
{
McpLog.Debug($"Failed to fetch login URL from {loginUrlEndpoint}: {ex.Message}");
}
return null;
}
private void OnClearApiKeyClicked()
{
EditorPrefs.SetString(EditorPrefKeys.ApiKey, string.Empty);
if (apiKeyField != null)
{
apiKeyField.SetValueWithoutNotify(string.Empty);
}
OnManualConfigUpdateRequested?.Invoke();
UpdateConnectionStatus();
McpLog.Info("API key cleared");
}
public async Task VerifyBridgeConnectionAsync() public async Task VerifyBridgeConnectionAsync()
{ {
// Prevent concurrent verification calls // Prevent concurrent verification calls
@ -810,17 +994,16 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
return; return;
} }
// Determine the server's current transport setting // Determine the server's current transport setting (3-way: Stdio, Http, HttpRemote)
bool serverUsesHttp = EditorConfigurationCache.Instance.UseHttpTransport; ConfiguredTransport serverTransport = HttpEndpointUtility.GetCurrentServerTransport();
ConfiguredTransport serverTransport = serverUsesHttp ? ConfiguredTransport.Http : ConfiguredTransport.Stdio;
// Check for mismatch // Check for mismatch
bool hasMismatch = clientTransport != serverTransport; bool hasMismatch = clientTransport != serverTransport;
if (hasMismatch) if (hasMismatch)
{ {
string clientTransportName = clientTransport == ConfiguredTransport.Http ? "HTTP" : "stdio"; string clientTransportName = TransportDisplayName(clientTransport);
string serverTransportName = serverTransport == ConfiguredTransport.Http ? "HTTP" : "stdio"; string serverTransportName = TransportDisplayName(serverTransport);
transportMismatchText.text = $"⚠ {clientName} is configured for \"{clientTransportName}\" but server is set to \"{serverTransportName}\". " + transportMismatchText.text = $"⚠ {clientName} is configured for \"{clientTransportName}\" but server is set to \"{serverTransportName}\". " +
"Click \"Configure\" in Client Configuration to update."; "Click \"Configure\" in Client Configuration to update.";
@ -839,5 +1022,16 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
{ {
transportMismatchWarning?.RemoveFromClassList("visible"); transportMismatchWarning?.RemoveFromClassList("visible");
} }
private static string TransportDisplayName(ConfiguredTransport transport)
{
return transport switch
{
ConfiguredTransport.Stdio => "stdio",
ConfiguredTransport.Http => "HTTP Local",
ConfiguredTransport.HttpRemote => "HTTP Remote",
_ => "unknown"
};
}
} }
} }

View File

@ -14,6 +14,16 @@
<ui:Label text="HTTP URL:" class="setting-label" /> <ui:Label text="HTTP URL:" class="setting-label" />
<ui:TextField name="http-url" class="url-field" /> <ui:TextField name="http-url" class="url-field" />
</ui:VisualElement> </ui:VisualElement>
<ui:VisualElement name="api-key-row" style="margin-bottom: 4px;">
<ui:VisualElement class="setting-row">
<ui:Label text="API Key:" class="setting-label" />
<ui:TextField name="api-key-field" password="true" class="url-field" />
</ui:VisualElement>
<ui:VisualElement style="flex-direction: row; justify-content: flex-end;">
<ui:Button name="get-api-key-button" text="Get API Key" class="action-button" />
<ui:Button name="clear-api-key-button" text="Clear" class="action-button" />
</ui:VisualElement>
</ui:VisualElement>
<ui:VisualElement class="setting-row" name="http-server-control-row"> <ui:VisualElement class="setting-row" name="http-server-control-row">
<ui:Label text="Local Server:" class="setting-label" /> <ui:Label text="Local Server:" class="setting-label" />
<ui:Button name="start-http-server-button" text="Start Server" class="action-button start-server-button" /> <ui:Button name="start-http-server-button" text="Start Server" class="action-button start-server-button" />

View File

@ -56,6 +56,7 @@ namespace MCPForUnity.Editor.Windows
{ EditorPrefKeys.ClaudeCliPathOverride, EditorPrefType.String }, { EditorPrefKeys.ClaudeCliPathOverride, EditorPrefType.String },
{ EditorPrefKeys.UvxPathOverride, EditorPrefType.String }, { EditorPrefKeys.UvxPathOverride, EditorPrefType.String },
{ EditorPrefKeys.HttpBaseUrl, EditorPrefType.String }, { EditorPrefKeys.HttpBaseUrl, EditorPrefType.String },
{ EditorPrefKeys.HttpRemoteBaseUrl, EditorPrefType.String },
{ EditorPrefKeys.HttpTransportScope, EditorPrefType.String }, { EditorPrefKeys.HttpTransportScope, EditorPrefType.String },
{ EditorPrefKeys.SessionId, EditorPrefType.String }, { EditorPrefKeys.SessionId, EditorPrefType.String },
{ EditorPrefKeys.WebSocketUrlOverride, EditorPrefType.String }, { EditorPrefKeys.WebSocketUrlOverride, EditorPrefType.String },

View File

@ -63,6 +63,54 @@ docker run -p 8080:8080 -e LOG_LEVEL=DEBUG msanatan/mcp-for-unity-server:latest
--- ---
## Remote-Hosted Mode
To deploy as a shared remote service with API key authentication and per-user session isolation, pass `--http-remote-hosted` along with an API key validation URL:
```bash
docker run -p 8080:8080 \
-e UNITY_MCP_HTTP_REMOTE_HOSTED=true \
-e UNITY_MCP_API_KEY_VALIDATION_URL=https://auth.example.com/api/validate-key \
-e UNITY_MCP_API_KEY_LOGIN_URL=https://app.example.com/api-keys \
msanatan/mcp-for-unity-server:latest
```
In this mode:
- All MCP tool/resource calls and Unity plugin WebSocket connections require a valid `X-API-Key` header.
- Each user only sees Unity instances that connected with their API key.
- Users must explicitly call `set_active_instance` to select a Unity instance.
**Remote-hosted environment variables:**
| Variable | Description |
|----------|-------------|
| `UNITY_MCP_HTTP_REMOTE_HOSTED` | Enable remote-hosted mode (`true`, `1`, or `yes`) |
| `UNITY_MCP_API_KEY_VALIDATION_URL` | External endpoint to validate API keys (required) |
| `UNITY_MCP_API_KEY_LOGIN_URL` | URL where users can obtain/manage API keys |
| `UNITY_MCP_API_KEY_CACHE_TTL` | Cache TTL for validated keys in seconds (default: `300`) |
| `UNITY_MCP_API_KEY_SERVICE_TOKEN_HEADER` | Header name for server-to-auth-service authentication |
| `UNITY_MCP_API_KEY_SERVICE_TOKEN` | Token value sent to the auth service |
**MCP client config with API key:**
```json
{
"mcpServers": {
"UnityMCP": {
"url": "http://your-server:8080/mcp",
"headers": {
"X-API-Key": "<your-api-key>"
}
}
}
}
```
For full details, see the [Remote Server Auth Guide](https://github.com/CoplayDev/unity-mcp/blob/main/docs/guides/REMOTE_SERVER_AUTH.md).
---
## Example Prompts ## Example Prompts
Once connected, try these commands in your AI assistant: Once connected, try these commands in your AI assistant:

View File

@ -113,12 +113,124 @@ uv run src/main.py --transport stdio
## Configuration ## Configuration
The server connects to Unity Editor automatically when both are running. No additional configuration needed. The server connects to Unity Editor automatically when both are running. Most users do not need to change any settings.
**Environment Variables:** ### CLI options
- `DISABLE_TELEMETRY=true` - Opt out of anonymous usage analytics These options apply to the `mcp-for-unity` command (whether run via `uvx`, Docker, or `python src/main.py`).
- `LOG_LEVEL=DEBUG` - Enable detailed logging (default: INFO)
- `--transport {stdio,http}` - Transport protocol (default: `stdio`)
- `--http-url URL` - Base URL used to derive host/port defaults (default: `http://localhost:8080`)
- `--http-host HOST` - Override HTTP bind host (overrides URL host)
- `--http-port PORT` - Override HTTP bind port (overrides URL port)
- `--http-remote-hosted` - Treat HTTP transport as remotely hosted
- Requires API key authentication (see below)
- Disables local/CLI-only HTTP routes (`/api/command`, `/api/instances`, `/api/custom-tools`)
- Forces explicit Unity instance selection for MCP tool/resource calls
- Isolates Unity sessions per user
- `--api-key-validation-url URL` - External endpoint to validate API keys (required when `--http-remote-hosted` is set)
- `--api-key-login-url URL` - URL where users can obtain/manage API keys (served by `/api/auth/login-url`)
- `--api-key-cache-ttl SECONDS` - Cache duration for validated keys (default: `300`)
- `--api-key-service-token-header HEADER` - Header name for server-to-auth-service authentication (e.g. `X-Service-Token`)
- `--api-key-service-token TOKEN` - Token value sent to the auth service for server authentication
- `--default-instance INSTANCE` - Default Unity instance to target (project name, hash, or `Name@hash`)
- `--project-scoped-tools` - Keep custom tools scoped to the active Unity project and enable the custom tools resource
- `--unity-instance-token TOKEN` - Optional per-launch token set by Unity for deterministic lifecycle management
- `--pidfile PATH` - Optional path where the server writes its PID on startup (used by Unity-managed terminal launches)
### Environment variables
- `UNITY_MCP_TRANSPORT` - Transport protocol: `stdio` or `http`
- `UNITY_MCP_HTTP_URL` - HTTP server URL (default: `http://localhost:8080`)
- `UNITY_MCP_HTTP_HOST` - HTTP bind host (overrides URL host)
- `UNITY_MCP_HTTP_PORT` - HTTP bind port (overrides URL port)
- `UNITY_MCP_HTTP_REMOTE_HOSTED` - Enable remote-hosted mode (`true`, `1`, or `yes`)
- `UNITY_MCP_DEFAULT_INSTANCE` - Default Unity instance to target (project name, hash, or `Name@hash`)
- `UNITY_MCP_SKIP_STARTUP_CONNECT=1` - Skip initial Unity connection attempt on startup
API key authentication (remote-hosted mode):
- `UNITY_MCP_API_KEY_VALIDATION_URL` - External endpoint to validate API keys
- `UNITY_MCP_API_KEY_LOGIN_URL` - URL where users can obtain/manage API keys
- `UNITY_MCP_API_KEY_CACHE_TTL` - Cache TTL for validated keys in seconds (default: `300`)
- `UNITY_MCP_API_KEY_SERVICE_TOKEN_HEADER` - Header name for server-to-auth-service authentication
- `UNITY_MCP_API_KEY_SERVICE_TOKEN` - Token value sent to the auth service for server authentication
Telemetry:
- `DISABLE_TELEMETRY=1` - Disable anonymous telemetry (opt-out)
- `UNITY_MCP_DISABLE_TELEMETRY=1` - Same as `DISABLE_TELEMETRY`
- `MCP_DISABLE_TELEMETRY=1` - Same as `DISABLE_TELEMETRY`
- `UNITY_MCP_TELEMETRY_ENDPOINT` - Override telemetry endpoint URL
- `UNITY_MCP_TELEMETRY_TIMEOUT` - Override telemetry request timeout (seconds)
### Examples
**Stdio (default):**
```bash
uvx --from mcpforunityserver mcp-for-unity --transport stdio
```
**HTTP (local):**
```bash
uvx --from mcpforunityserver mcp-for-unity --transport http --http-host 127.0.0.1 --http-port 8080
```
**HTTP (remote-hosted with API key auth):**
```bash
uvx --from mcpforunityserver mcp-for-unity \
--transport http \
--http-host 0.0.0.0 \
--http-port 8080 \
--http-remote-hosted \
--api-key-validation-url https://auth.example.com/api/validate-key \
--api-key-login-url https://app.example.com/api-keys
```
**Disable telemetry:**
```bash
DISABLE_TELEMETRY=1 uvx --from mcpforunityserver mcp-for-unity --transport stdio
```
---
## Remote-Hosted Mode
When deploying the server as a shared remote service (e.g. for a team or Asset Store users), enable `--http-remote-hosted` to activate API key authentication and per-user session isolation.
**Requirements:**
- An external HTTP endpoint that validates API keys. The server POSTs `{"api_key": "..."}` and expects `{"valid": true, "user_id": "..."}` or `{"valid": false}` in response.
- `--api-key-validation-url` must be provided (or `UNITY_MCP_API_KEY_VALIDATION_URL`). The server exits with code 1 if this is missing.
**What changes in remote-hosted mode:**
- All MCP tool/resource calls and Unity plugin WebSocket connections require a valid `X-API-Key` header.
- Each user only sees Unity instances that connected with their API key (session isolation).
- Auto-selection of a sole Unity instance is disabled; users must explicitly call `set_active_instance`.
- CLI REST routes (`/api/command`, `/api/instances`, `/api/custom-tools`) are disabled.
- `/health` and `/api/auth/login-url` remain accessible without authentication.
**MCP client config with API key:**
```json
{
"mcpServers": {
"UnityMCP": {
"url": "http://remote-server:8080/mcp",
"headers": {
"X-API-Key": "<your-api-key>"
}
}
}
}
```
For full details, see [Remote Server Auth Guide](../docs/guides/REMOTE_SERVER_AUTH.md) and [Architecture Reference](../docs/reference/REMOTE_SERVER_AUTH_ARCHITECTURE.md).
--- ---

View File

@ -182,38 +182,34 @@ async def list_unity_instances(config: Optional[CLIConfig] = None) -> Dict[str,
""" """
cfg = config or get_config() cfg = config or get_config()
# Try the new /api/instances endpoint first, fall back to /plugin/sessions url = f"http://{cfg.host}:{cfg.port}/api/instances"
urls_to_try = [
f"http://{cfg.host}:{cfg.port}/api/instances",
f"http://{cfg.host}:{cfg.port}/plugin/sessions",
]
async with httpx.AsyncClient() as client:
for url in urls_to_try:
try: try:
async with httpx.AsyncClient() as client:
response = await client.get(url, timeout=10) response = await client.get(url, timeout=10)
if response.status_code == 200: response.raise_for_status()
data = response.json() data = response.json()
# Normalize response format
if "instances" in data: if "instances" in data:
return data return data
elif "sessions" in data: except httpx.ConnectError as e:
# Convert sessions format to instances format
instances = []
for session_id, details in data["sessions"].items():
instances.append({
"session_id": session_id,
"project": details.get("project", "Unknown"),
"hash": details.get("hash", ""),
"unity_version": details.get("unity_version", "Unknown"),
"connected_at": details.get("connected_at", ""),
})
return {"success": True, "instances": instances}
except Exception:
continue
raise UnityConnectionError( raise UnityConnectionError(
"Failed to list Unity instances: No working endpoint found") f"Cannot connect to Unity MCP server at {cfg.host}:{cfg.port}. "
f"Make sure the server is running and Unity is connected.\n"
f"Error: {e}"
)
except httpx.TimeoutException:
raise UnityConnectionError(
"Connection to Unity timed out while listing instances. "
"Unity may be busy or unresponsive."
)
except httpx.HTTPStatusError as e:
raise UnityConnectionError(
f"HTTP error from server: {e.response.status_code} - {e.response.text}"
)
except Exception as e:
raise UnityConnectionError(f"Unexpected error: {e}")
raise UnityConnectionError("Failed to list Unity instances")
def run_list_instances(config: Optional[CLIConfig] = None) -> Dict[str, Any]: def run_list_instances(config: Optional[CLIConfig] = None) -> Dict[str, Any]:

View File

@ -15,6 +15,21 @@ class ServerConfig:
unity_port: int = 6400 unity_port: int = 6400
mcp_port: int = 6500 mcp_port: int = 6500
# Transport settings
transport_mode: str = "stdio"
# HTTP transport behaviour
http_remote_hosted: bool = False
# API key authentication (required when http_remote_hosted=True)
api_key_validation_url: str | None = None # POST endpoint to validate keys
api_key_login_url: str | None = None # URL for users to get/manage keys
# Cache TTL in seconds (5 min default)
api_key_cache_ttl: float = 300.0
# Optional service token for authenticating to the validation endpoint
api_key_service_token_header: str | None = None # e.g. "X-Service-Token"
api_key_service_token: str | None = None # The token value
# Connection settings # Connection settings
connection_timeout: float = 30.0 connection_timeout: float = 30.0
buffer_size: int = 16 * 1024 * 1024 # 16MB buffer buffer_size: int = 16 * 1024 * 1024 # 16MB buffer

View File

@ -0,0 +1,4 @@
"""Server-wide protocol constants."""
# HTTP header name for API key authentication
API_KEY_HEADER = "X-API-Key"

View File

@ -3,6 +3,7 @@ from transport.unity_instance_middleware import (
UnityInstanceMiddleware, UnityInstanceMiddleware,
get_unity_instance_middleware get_unity_instance_middleware
) )
from services.api_key_service import ApiKeyService
from transport.legacy.unity_connection import get_unity_connection_pool, UnityConnectionPool from transport.legacy.unity_connection import get_unity_connection_pool, UnityConnectionPool
from services.tools import register_all_tools from services.tools import register_all_tools
from core.telemetry import record_milestone, record_telemetry, MilestoneType, RecordType, get_package_version from core.telemetry import record_milestone, record_telemetry, MilestoneType, RecordType, get_package_version
@ -312,6 +313,15 @@ Payload sizing & paging (important):
""" """
def _normalize_instance_token(instance_token: str | None) -> tuple[str | None, str | None]:
if not instance_token:
return None, None
if "@" in instance_token:
name_part, _, hash_part = instance_token.partition("@")
return (name_part or None), (hash_part or None)
return None, instance_token
def create_mcp_server(project_scoped_tools: bool) -> FastMCP: def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
mcp = FastMCP( mcp = FastMCP(
name="mcp-for-unity-server", name="mcp-for-unity-server",
@ -332,14 +342,24 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
"message": "MCP for Unity server is running" "message": "MCP for Unity server is running"
}) })
def _normalize_instance_token(instance_token: str | None) -> tuple[str | None, str | None]: @mcp.custom_route("/api/auth/login-url", methods=["GET"])
if not instance_token: async def auth_login_url(_: Request) -> JSONResponse:
return None, None """Return the login URL for users to obtain/manage API keys."""
if "@" in instance_token: if not config.api_key_login_url:
name_part, _, hash_part = instance_token.partition("@") return JSONResponse(
return (name_part or None), (hash_part or None) {
return None, instance_token "success": False,
"error": "API key management not configured. Contact your server administrator.",
},
status_code=404,
)
return JSONResponse({
"success": True,
"login_url": config.api_key_login_url,
})
# Only expose CLI routes if running locally (not in remote hosted mode)
if not config.http_remote_hosted:
@mcp.custom_route("/api/command", methods=["POST"]) @mcp.custom_route("/api/command", methods=["POST"])
async def cli_command_route(request: Request) -> JSONResponse: async def cli_command_route(request: Request) -> JSONResponse:
"""REST endpoint for CLI commands to Unity.""" """REST endpoint for CLI commands to Unity."""
@ -364,7 +384,8 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
# Find target session # Find target session
session_id = None session_id = None
session_details = None session_details = None
instance_name, instance_hash = _normalize_instance_token(unity_instance) instance_name, instance_hash = _normalize_instance_token(
unity_instance)
if unity_instance: if unity_instance:
# Try to match by hash or project name # Try to match by hash or project name
for sid, details in sessions.sessions.items(): for sid, details in sessions.sessions.items():
@ -391,19 +412,23 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
tool_name = None tool_name = None
tool_params = {} tool_params = {}
if isinstance(params, dict): if isinstance(params, dict):
tool_name = params.get("tool_name") or params.get("name") tool_name = params.get(
tool_params = params.get("parameters") or params.get("params") or {} "tool_name") or params.get("name")
tool_params = params.get(
"parameters") or params.get("params") or {}
if not tool_name: if not tool_name:
return JSONResponse( return JSONResponse(
{"success": False, "error": "Missing 'tool_name' for execute_custom_tool"}, {"success": False,
"error": "Missing 'tool_name' for execute_custom_tool"},
status_code=400, status_code=400,
) )
if tool_params is None: if tool_params is None:
tool_params = {} tool_params = {}
if not isinstance(tool_params, dict): if not isinstance(tool_params, dict):
return JSONResponse( return JSONResponse(
{"success": False, "error": "Tool parameters must be an object/dict"}, {"success": False,
"error": "Tool parameters must be an object/dict"},
status_code=400, status_code=400,
) )
@ -416,7 +441,8 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
unity_instance_hint) unity_instance_hint)
if not project_id: if not project_id:
return JSONResponse( return JSONResponse(
{"success": False, "error": "Could not resolve project id for custom tool"}, {"success": False,
"error": "Could not resolve project id for custom tool"},
status_code=400, status_code=400,
) )
@ -431,7 +457,25 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
return JSONResponse(result) return JSONResponse(result)
except Exception as e: except Exception as e:
logger.error(f"CLI command error: {e}") logger.exception("CLI command error: %s", e)
return JSONResponse({"success": False, "error": str(e)}, status_code=500)
@mcp.custom_route("/api/instances", methods=["GET"])
async def cli_instances_route(_: Request) -> JSONResponse:
"""REST endpoint to list connected Unity instances."""
try:
sessions = await PluginHub.get_sessions()
instances = []
for session_id, details in sessions.sessions.items():
instances.append({
"session_id": session_id,
"project": details.project,
"hash": details.hash,
"unity_version": details.unity_version,
"connected_at": details.connected_at,
})
return JSONResponse({"success": True, "instances": instances})
except Exception as e:
return JSONResponse({"success": False, "error": str(e)}, status_code=500) return JSONResponse({"success": False, "error": str(e)}, status_code=500)
@mcp.custom_route("/api/custom-tools", methods=["GET"]) @mcp.custom_route("/api/custom-tools", methods=["GET"])
@ -439,7 +483,8 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
"""REST endpoint to list custom tools for the active Unity project.""" """REST endpoint to list custom tools for the active Unity project."""
try: try:
unity_instance = request.query_params.get("instance") unity_instance = request.query_params.get("instance")
instance_name, instance_hash = _normalize_instance_token(unity_instance) instance_name, instance_hash = _normalize_instance_token(
unity_instance)
sessions = await PluginHub.get_sessions() sessions = await PluginHub.get_sessions()
if not sessions.sessions: if not sessions.sessions:
@ -475,7 +520,8 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
unity_instance_hint) unity_instance_hint)
if not project_id: if not project_id:
return JSONResponse( return JSONResponse(
{"success": False, "error": "Could not resolve project id for custom tools"}, {"success": False,
"error": "Could not resolve project id for custom tools"},
status_code=400, status_code=400,
) )
@ -492,38 +538,29 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
"tools": tools_payload, "tools": tools_payload,
}) })
except Exception as e: except Exception as e:
logger.error(f"CLI custom tools error: {e}") logger.exception("CLI custom tools error: %s", e)
return JSONResponse({"success": False, "error": str(e)}, status_code=500) return JSONResponse({"success": False, "error": str(e)}, status_code=500)
@mcp.custom_route("/api/instances", methods=["GET"])
async def cli_instances_route(_: Request) -> JSONResponse:
"""REST endpoint to list connected Unity instances."""
try:
sessions = await PluginHub.get_sessions()
instances = []
for session_id, details in sessions.sessions.items():
instances.append({
"session_id": session_id,
"project": details.project,
"hash": details.hash,
"unity_version": details.unity_version,
"connected_at": details.connected_at,
})
return JSONResponse({"success": True, "instances": instances})
except Exception as e:
return JSONResponse({"success": False, "error": str(e)}, status_code=500)
@mcp.custom_route("/plugin/sessions", methods=["GET"])
async def plugin_sessions_route(_: Request) -> JSONResponse:
data = await PluginHub.get_sessions()
return JSONResponse(data.model_dump())
# Initialize and register middleware for session-based Unity instance routing # Initialize and register middleware for session-based Unity instance routing
# Using the singleton getter ensures we use the same instance everywhere # Using the singleton getter ensures we use the same instance everywhere
unity_middleware = get_unity_instance_middleware() unity_middleware = get_unity_instance_middleware()
mcp.add_middleware(unity_middleware) mcp.add_middleware(unity_middleware)
logger.info("Registered Unity instance middleware for session-based routing") logger.info("Registered Unity instance middleware for session-based routing")
# Initialize API key authentication if in remote-hosted mode
if config.http_remote_hosted and config.api_key_validation_url:
ApiKeyService(
validation_url=config.api_key_validation_url,
cache_ttl=config.api_key_cache_ttl,
service_token_header=config.api_key_service_token_header,
service_token=config.api_key_service_token,
)
logger.info(
"Initialized API key authentication service (validation URL: %s, TTL: %.0fs)",
config.api_key_validation_url,
config.api_key_cache_ttl,
)
# Mount plugin websocket hub at /hub/plugin when HTTP transport is active # Mount plugin websocket hub at /hub/plugin when HTTP transport is active
existing_routes = [ existing_routes = [
route for route in mcp._get_additional_http_routes() route for route in mcp._get_additional_http_routes()
@ -610,6 +647,54 @@ Examples:
help="HTTP server port (overrides URL port). " help="HTTP server port (overrides URL port). "
"Overrides UNITY_MCP_HTTP_PORT environment variable." "Overrides UNITY_MCP_HTTP_PORT environment variable."
) )
parser.add_argument(
"--http-remote-hosted",
action="store_true",
help="Treat HTTP transport as remotely hosted (forces explicit Unity instance selection). "
"Can also set via UNITY_MCP_HTTP_REMOTE_HOSTED=true."
)
parser.add_argument(
"--api-key-validation-url",
type=str,
default=None,
metavar="URL",
help="External URL to validate API keys (POST with {'api_key': '...'}). "
"Required when --http-remote-hosted is set. "
"Can also set via UNITY_MCP_API_KEY_VALIDATION_URL."
)
parser.add_argument(
"--api-key-login-url",
type=str,
default=None,
metavar="URL",
help="URL where users can obtain/manage API keys. "
"Returned by /api/auth/login-url endpoint. "
"Can also set via UNITY_MCP_API_KEY_LOGIN_URL."
)
parser.add_argument(
"--api-key-cache-ttl",
type=float,
default=300.0,
metavar="SECONDS",
help="Cache TTL for validated API keys in seconds (default: 300). "
"Can also set via UNITY_MCP_API_KEY_CACHE_TTL."
)
parser.add_argument(
"--api-key-service-token-header",
type=str,
default=None,
metavar="HEADER",
help="Header name for service token sent to validation endpoint (e.g. X-Service-Token). "
"Can also set via UNITY_MCP_API_KEY_SERVICE_TOKEN_HEADER."
)
parser.add_argument(
"--api-key-service-token",
type=str,
default=None,
metavar="TOKEN",
help="Service token value sent to validation endpoint for server authentication. "
"WARNING: Prefer UNITY_MCP_API_KEY_SERVICE_TOKEN env var in production to avoid process listing exposure."
)
parser.add_argument( parser.add_argument(
"--unity-instance-token", "--unity-instance-token",
type=str, type=str,
@ -629,7 +714,8 @@ Examples:
parser.add_argument( parser.add_argument(
"--project-scoped-tools", "--project-scoped-tools",
action="store_true", action="store_true",
help="Keep custom tools scoped to the active Unity project and enable the custom tools resource." help="Keep custom tools scoped to the active Unity project and enable the custom tools resource. "
"Can also set via UNITY_MCP_PROJECT_SCOPED_TOOLS=true."
) )
args = parser.parse_args() args = parser.parse_args()
@ -641,10 +727,52 @@ Examples:
f"Using default Unity instance from command-line: {args.default_instance}") f"Using default Unity instance from command-line: {args.default_instance}")
# Set transport mode # Set transport mode
transport_mode = args.transport or os.environ.get( config.transport_mode = args.transport or os.environ.get(
"UNITY_MCP_TRANSPORT", "stdio") "UNITY_MCP_TRANSPORT", "stdio")
os.environ["UNITY_MCP_TRANSPORT"] = transport_mode logger.info(f"Transport mode: {config.transport_mode}")
logger.info(f"Transport mode: {transport_mode}")
config.http_remote_hosted = (
bool(args.http_remote_hosted)
or os.environ.get("UNITY_MCP_HTTP_REMOTE_HOSTED", "").lower() in ("true", "1", "yes", "on")
)
# API key authentication configuration
config.api_key_validation_url = (
args.api_key_validation_url
or os.environ.get("UNITY_MCP_API_KEY_VALIDATION_URL")
)
config.api_key_login_url = (
args.api_key_login_url
or os.environ.get("UNITY_MCP_API_KEY_LOGIN_URL")
)
try:
cache_ttl_env = os.environ.get("UNITY_MCP_API_KEY_CACHE_TTL")
config.api_key_cache_ttl = (
float(cache_ttl_env) if cache_ttl_env else args.api_key_cache_ttl
)
except ValueError:
logger.warning(
"Invalid UNITY_MCP_API_KEY_CACHE_TTL value, using default 300.0"
)
config.api_key_cache_ttl = 300.0
# Service token for authenticating to validation endpoint
config.api_key_service_token_header = (
args.api_key_service_token_header
or os.environ.get("UNITY_MCP_API_KEY_SERVICE_TOKEN_HEADER")
)
config.api_key_service_token = (
args.api_key_service_token
or os.environ.get("UNITY_MCP_API_KEY_SERVICE_TOKEN")
)
# Validate: remote-hosted HTTP mode requires API key validation URL
if config.http_remote_hosted and config.transport_mode == "http" and not config.api_key_validation_url:
logger.error(
"--http-remote-hosted requires --api-key-validation-url or "
"UNITY_MCP_API_KEY_VALIDATION_URL environment variable"
)
raise SystemExit(1)
http_url = os.environ.get("UNITY_MCP_HTTP_URL", args.http_url) http_url = os.environ.get("UNITY_MCP_HTTP_URL", args.http_url)
parsed_url = urlparse(http_url) parsed_url = urlparse(http_url)
@ -688,10 +816,14 @@ Examples:
if args.http_port: if args.http_port:
logger.info(f"HTTP port override: {http_port}") logger.info(f"HTTP port override: {http_port}")
mcp = create_mcp_server(args.project_scoped_tools) project_scoped_tools = (
bool(args.project_scoped_tools)
or os.environ.get("UNITY_MCP_PROJECT_SCOPED_TOOLS", "").lower() in ("true", "1", "yes", "on")
)
mcp = create_mcp_server(project_scoped_tools)
# Determine transport mode # Determine transport mode
if transport_mode == 'http': if config.transport_mode == 'http':
# Use HTTP transport for FastMCP # Use HTTP transport for FastMCP
transport = 'http' transport = 'http'
# Use the parsed host and port from URL/args # Use the parsed host and port from URL/args

View File

@ -0,0 +1,235 @@
"""API Key validation service for remote-hosted mode."""
from __future__ import annotations
import asyncio
import logging
import time
from dataclasses import dataclass
from typing import Any
import httpx
logger = logging.getLogger("mcp-for-unity-server")
@dataclass
class ValidationResult:
"""Result of an API key validation."""
valid: bool
user_id: str | None = None
metadata: dict[str, Any] | None = None
error: str | None = None
cacheable: bool = True
class ApiKeyService:
"""Service for validating API keys against an external auth endpoint.
Follows the class-level singleton pattern for global access by MCP tools.
"""
_instance: "ApiKeyService | None" = None
# Request defaults (sensible hardening)
REQUEST_TIMEOUT: float = 5.0
MAX_RETRIES: int = 1
def __init__(
self,
validation_url: str,
cache_ttl: float = 300.0,
service_token_header: str | None = None,
service_token: str | None = None,
):
"""Initialize the API key service.
Args:
validation_url: External URL to validate API keys (POST with {"api_key": "..."})
cache_ttl: Cache TTL for validated keys in seconds (default: 300)
service_token_header: Optional header name for service authentication (e.g. "X-Service-Token")
service_token: Optional token value for service authentication
"""
self._validation_url = validation_url
self._cache_ttl = cache_ttl
self._service_token_header = service_token_header
self._service_token = service_token
# Cache: api_key -> (valid, user_id, metadata, expires_at)
self._cache: dict[str, tuple[bool, str |
None, dict[str, Any] | None, float]] = {}
self._cache_lock = asyncio.Lock()
ApiKeyService._instance = self
@classmethod
def get_instance(cls) -> "ApiKeyService":
"""Get the singleton instance.
Raises:
RuntimeError: If the service has not been initialized.
"""
if cls._instance is None:
raise RuntimeError("ApiKeyService not initialized")
return cls._instance
@classmethod
def is_initialized(cls) -> bool:
"""Check if the service has been initialized."""
return cls._instance is not None
async def validate(self, api_key: str) -> ValidationResult:
"""Validate an API key.
Returns:
ValidationResult with valid=True and user_id if valid,
or valid=False with error message if invalid.
"""
if not api_key:
return ValidationResult(valid=False, error="API key required")
# Check cache first
async with self._cache_lock:
cached = self._cache.get(api_key)
if cached is not None:
valid, user_id, metadata, expires_at = cached
if time.time() < expires_at:
if valid:
return ValidationResult(valid=True, user_id=user_id, metadata=metadata)
else:
return ValidationResult(valid=False, error="Invalid API key")
else:
# Expired, remove from cache
del self._cache[api_key]
# Call external validation URL
result = await self._validate_external(api_key)
# Only cache definitive results (valid keys and confirmed-invalid keys).
# Transient failures (auth service unavailable, timeouts, etc.) should
# not be cached to avoid locking out users during service outages.
if result.cacheable:
async with self._cache_lock:
expires_at = time.time() + self._cache_ttl
self._cache[api_key] = (
result.valid,
result.user_id,
result.metadata,
expires_at,
)
return result
async def _validate_external(self, api_key: str) -> ValidationResult:
"""Call external validation endpoint.
Failure mode: fail closed (treat as invalid on errors).
"""
# Redact API key from logs
redacted_key = f"{api_key[:4]}...{api_key[-4:]}" if len(
api_key) > 8 else "***"
for attempt in range(self.MAX_RETRIES + 1):
try:
async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as client:
# Build request headers
headers = {"Content-Type": "application/json"}
if self._service_token_header and self._service_token:
headers[self._service_token_header] = self._service_token
response = await client.post(
self._validation_url,
json={"api_key": api_key},
headers=headers,
)
if response.status_code == 200:
data = response.json()
if data.get("valid"):
return ValidationResult(
valid=True,
user_id=data.get("user_id"),
metadata=data.get("metadata"),
)
else:
return ValidationResult(
valid=False,
error=data.get("error", "Invalid API key"),
)
elif response.status_code == 401:
return ValidationResult(valid=False, error="Invalid API key")
else:
logger.warning(
"API key validation returned status %d for key %s",
response.status_code,
redacted_key,
)
# Fail closed but don't cache (transient service error)
return ValidationResult(
valid=False,
error=f"Auth service error (status {response.status_code})",
cacheable=False,
)
except httpx.TimeoutException:
if attempt < self.MAX_RETRIES:
logger.debug(
"API key validation timeout for key %s, retrying...",
redacted_key,
)
await asyncio.sleep(0.1 * (attempt + 1))
continue
logger.warning(
"API key validation timeout for key %s after %d attempts",
redacted_key,
attempt + 1,
)
return ValidationResult(
valid=False,
error="Auth service timeout",
cacheable=False,
)
except httpx.RequestError as exc:
if attempt < self.MAX_RETRIES:
logger.debug(
"API key validation request error for key %s: %s, retrying...",
redacted_key,
exc,
)
await asyncio.sleep(0.1 * (attempt + 1))
continue
logger.warning(
"API key validation request error for key %s: %s",
redacted_key,
exc,
)
return ValidationResult(
valid=False,
error="Auth service unavailable",
cacheable=False,
)
except Exception as exc:
logger.error(
"Unexpected error validating API key %s: %s",
redacted_key,
exc,
)
return ValidationResult(
valid=False,
error="Auth service error",
cacheable=False,
)
# Should not reach here, but fail closed
return ValidationResult(valid=False, error="Auth service error", cacheable=False)
async def invalidate_cache(self, api_key: str) -> None:
"""Remove an API key from the cache."""
async with self._cache_lock:
self._cache.pop(api_key, None)
async def clear_cache(self) -> None:
"""Clear all cached validations."""
async with self._cache_lock:
self._cache.clear()
__all__ = ["ApiKeyService", "ValidationResult"]

View File

@ -5,12 +5,14 @@ from typing import Any
from fastmcp import Context from fastmcp import Context
from pydantic import BaseModel from pydantic import BaseModel
from core.config import config
from models import MCPResponse from models import MCPResponse
from services.registry import mcp_for_unity_resource from services.registry import mcp_for_unity_resource
from services.tools import get_unity_instance_from_context from services.tools import get_unity_instance_from_context
from services.state.external_changes_scanner import external_changes_scanner from services.state.external_changes_scanner import external_changes_scanner
import transport.unity_transport as unity_transport import transport.unity_transport as unity_transport
from transport.legacy.unity_connection import async_send_command_with_retry from transport.legacy.unity_connection import async_send_command_with_retry
from transport.plugin_hub import PluginHub
class EditorStateUnity(BaseModel): class EditorStateUnity(BaseModel):
@ -132,17 +134,15 @@ async def infer_single_instance_id(ctx: Context) -> str | None:
""" """
await ctx.info("If exactly one Unity instance is connected, return its Name@hash id.") await ctx.info("If exactly one Unity instance is connected, return its Name@hash id.")
try: transport = (config.transport_mode or "stdio").lower()
transport = unity_transport._current_transport()
except Exception:
transport = None
if transport == "http": if transport == "http":
# HTTP/WebSocket transport: derive from PluginHub sessions. # HTTP/WebSocket transport: derive from PluginHub sessions.
try: try:
from transport.plugin_hub import PluginHub # In remote-hosted mode, filter sessions by user_id
user_id = ctx.get_state(
sessions_data = await PluginHub.get_sessions() "user_id") if config.http_remote_hosted else None
sessions_data = await PluginHub.get_sessions(user_id=user_id)
sessions = sessions_data.sessions if hasattr( sessions = sessions_data.sessions if hasattr(
sessions_data, "sessions") else {} sessions_data, "sessions") else {}
if isinstance(sessions, dict) and len(sessions) == 1: if isinstance(sessions, dict) and len(sessions) == 1:

View File

@ -7,7 +7,7 @@ from fastmcp import Context
from services.registry import mcp_for_unity_resource from services.registry import mcp_for_unity_resource
from transport.legacy.unity_connection import get_unity_connection_pool from transport.legacy.unity_connection import get_unity_connection_pool
from transport.plugin_hub import PluginHub from transport.plugin_hub import PluginHub
from transport.unity_transport import _current_transport from core.config import config
@mcp_for_unity_resource( @mcp_for_unity_resource(
@ -36,10 +36,13 @@ async def unity_instances(ctx: Context) -> dict[str, Any]:
await ctx.info("Listing Unity instances") await ctx.info("Listing Unity instances")
try: try:
transport = _current_transport() transport = (config.transport_mode or "stdio").lower()
if transport == "http": if transport == "http":
# HTTP/WebSocket transport: query PluginHub # HTTP/WebSocket transport: query PluginHub
sessions_data = await PluginHub.get_sessions() # In remote-hosted mode, filter sessions by user_id
user_id = ctx.get_state(
"user_id") if config.http_remote_hosted else None
sessions_data = await PluginHub.get_sessions(user_id=user_id)
sessions = sessions_data.sessions sessions = sessions_data.sessions
instances = [] instances = []

View File

@ -8,7 +8,7 @@ from services.registry import mcp_for_unity_tool
from transport.legacy.unity_connection import get_unity_connection_pool from transport.legacy.unity_connection import get_unity_connection_pool
from transport.unity_instance_middleware import get_unity_instance_middleware from transport.unity_instance_middleware import get_unity_instance_middleware
from transport.plugin_hub import PluginHub from transport.plugin_hub import PluginHub
from transport.unity_transport import _current_transport from core.config import config
@mcp_for_unity_tool( @mcp_for_unity_tool(
@ -21,11 +21,14 @@ async def set_active_instance(
ctx: Context, ctx: Context,
instance: Annotated[str, "Target instance (Name@hash or hash prefix)"] instance: Annotated[str, "Target instance (Name@hash or hash prefix)"]
) -> dict[str, Any]: ) -> dict[str, Any]:
transport = _current_transport() transport = (config.transport_mode or "stdio").lower()
# Discover running instances based on transport # Discover running instances based on transport
if transport == "http": if transport == "http":
sessions_data = await PluginHub.get_sessions() # In remote-hosted mode, filter sessions by user_id
user_id = ctx.get_state(
"user_id") if config.http_remote_hosted else None
sessions_data = await PluginHub.get_sessions(user_id=user_id)
sessions = sessions_data.sessions sessions = sessions_data.sessions
instances = [] instances = []
for session_id, session in sessions.items(): for session_id, session in sessions.items():

View File

@ -13,8 +13,10 @@ from starlette.endpoints import WebSocketEndpoint
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
from core.config import config from core.config import config
from core.constants import API_KEY_HEADER
from models.models import MCPResponse from models.models import MCPResponse
from transport.plugin_registry import PluginRegistry from transport.plugin_registry import PluginRegistry
from services.api_key_service import ApiKeyService
from transport.models import ( from transport.models import (
WelcomeMessage, WelcomeMessage,
RegisteredMessage, RegisteredMessage,
@ -38,6 +40,22 @@ class NoUnitySessionError(RuntimeError):
"""Raised when no Unity plugins are available.""" """Raised when no Unity plugins are available."""
class InstanceSelectionRequiredError(RuntimeError):
"""Raised when the caller must explicitly select a Unity instance."""
_SELECTION_REQUIRED = (
"Unity instance selection is required. "
"Call set_active_instance with Name@hash from mcpforunity://instances."
)
_MULTIPLE_INSTANCES = (
"Multiple Unity instances are connected. "
"Call set_active_instance with Name@hash from mcpforunity://instances."
)
def __init__(self, message: str | None = None):
super().__init__(message or self._SELECTION_REQUIRED)
class PluginHub(WebSocketEndpoint): class PluginHub(WebSocketEndpoint):
"""Manages persistent WebSocket connections to Unity plugins.""" """Manages persistent WebSocket connections to Unity plugins."""
@ -77,6 +95,50 @@ class PluginHub(WebSocketEndpoint):
return cls._registry is not None and cls._lock is not None return cls._registry is not None and cls._lock is not None
async def on_connect(self, websocket: WebSocket) -> None: async def on_connect(self, websocket: WebSocket) -> None:
# Validate API key in remote-hosted mode (fail closed)
if config.http_remote_hosted:
if not ApiKeyService.is_initialized():
logger.debug(
"WebSocket connection rejected: auth service not initialized")
await websocket.close(code=1013, reason="Try again later")
return
api_key = websocket.headers.get(API_KEY_HEADER)
if not api_key:
logger.debug("WebSocket connection rejected: API key required")
await websocket.close(code=4401, reason="API key required")
return
service = ApiKeyService.get_instance()
result = await service.validate(api_key)
if not result.valid:
# Transient auth failures are retryable (1013)
if result.error and any(
indicator in result.error.lower()
for indicator in ("unavailable", "timeout", "service error")
):
logger.debug(
"WebSocket connection rejected: auth service unavailable")
await websocket.close(code=1013, reason="Try again later")
return
logger.debug("WebSocket connection rejected: invalid API key")
await websocket.close(code=4403, reason="Invalid API key")
return
# Both valid and user_id must be present to accept
if not result.user_id:
logger.debug(
"WebSocket connection rejected: validated key missing user_id")
await websocket.close(code=4403, reason="Invalid API key")
return
# Store user_id in websocket state for later use during registration
websocket.state.user_id = result.user_id
websocket.state.api_key_metadata = result.metadata
await websocket.accept() await websocket.accept()
msg = WelcomeMessage( msg = WelcomeMessage(
serverTimeout=self.SERVER_TIMEOUT, serverTimeout=self.SERVER_TIMEOUT,
@ -217,10 +279,15 @@ class PluginHub(WebSocketEndpoint):
cls._pending.pop(command_id, None) cls._pending.pop(command_id, None)
@classmethod @classmethod
async def get_sessions(cls) -> SessionList: async def get_sessions(cls, user_id: str | None = None) -> SessionList:
"""Get all active plugin sessions.
Args:
user_id: If provided (remote-hosted mode), only return sessions for this user.
"""
if cls._registry is None: if cls._registry is None:
return SessionList(sessions={}) return SessionList(sessions={})
sessions = await cls._registry.list_sessions() sessions = await cls._registry.list_sessions(user_id=user_id)
return SessionList( return SessionList(
sessions={ sessions={
session_id: SessionDetails( session_id: SessionDetails(
@ -286,14 +353,22 @@ class PluginHub(WebSocketEndpoint):
raise ValueError( raise ValueError(
"Plugin registration missing project_hash") "Plugin registration missing project_hash")
# Get user_id from websocket state (set during API key validation)
user_id = getattr(websocket.state, "user_id", None)
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
# Inform the plugin of its assigned session ID # Inform the plugin of its assigned session ID
response = RegisteredMessage(session_id=session_id) response = RegisteredMessage(session_id=session_id)
await websocket.send_json(response.model_dump()) await websocket.send_json(response.model_dump())
session = await registry.register(session_id, project_name, project_hash, unity_version, project_path) session = await registry.register(session_id, project_name, project_hash, unity_version, project_path, user_id=user_id)
async with lock: async with lock:
cls._connections[session.session_id] = websocket cls._connections[session.session_id] = websocket
if user_id:
logger.info(
f"Plugin registered: {project_name} ({project_hash}) for user {user_id}")
else:
logger.info(f"Plugin registered: {project_name} ({project_hash})") logger.info(f"Plugin registered: {project_name} ({project_hash})")
async def _handle_register_tools(self, websocket: WebSocket, payload: RegisterToolsMessage) -> None: async def _handle_register_tools(self, websocket: WebSocket, payload: RegisterToolsMessage) -> None:
@ -375,13 +450,17 @@ class PluginHub(WebSocketEndpoint):
# Session resolution helpers # Session resolution helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@classmethod @classmethod
async def _resolve_session_id(cls, unity_instance: str | None) -> str: async def _resolve_session_id(cls, unity_instance: str | None, user_id: str | None = None) -> str:
"""Resolve a project hash (Unity instance id) to an active plugin session. """Resolve a project hash (Unity instance id) to an active plugin session.
During Unity domain reloads the plugin's WebSocket session is torn down During Unity domain reloads the plugin's WebSocket session is torn down
and reconnected shortly afterwards. Instead of failing immediately when and reconnected shortly afterwards. Instead of failing immediately when
no sessions are available, we wait for a bounded period for a plugin no sessions are available, we wait for a bounded period for a plugin
to reconnect so in-flight MCP calls can succeed transparently. to reconnect so in-flight MCP calls can succeed transparently.
Args:
unity_instance: Target instance (Name@hash or hash)
user_id: User ID from API key validation (for remote-hosted mode session isolation)
""" """
if cls._registry is None: if cls._registry is None:
raise RuntimeError("Plugin registry not configured") raise RuntimeError("Plugin registry not configured")
@ -411,24 +490,35 @@ class PluginHub(WebSocketEndpoint):
else: else:
target_hash = unity_instance target_hash = unity_instance
async def _try_once() -> tuple[str | None, int]: async def _try_once() -> tuple[str | None, int, bool]:
explicit_required = config.http_remote_hosted
# Prefer a specific Unity instance if one was requested # Prefer a specific Unity instance if one was requested
if target_hash: if target_hash:
# In remote-hosted mode with user_id, use user-scoped lookup
if config.http_remote_hosted and user_id:
session_id = await cls._registry.get_session_id_by_hash(target_hash, user_id)
sessions = await cls._registry.list_sessions(user_id=user_id)
else:
session_id = await cls._registry.get_session_id_by_hash(target_hash) session_id = await cls._registry.get_session_id_by_hash(target_hash)
sessions = await cls._registry.list_sessions() sessions = await cls._registry.list_sessions(user_id=user_id)
return session_id, len(sessions) return session_id, len(sessions), explicit_required
# No target provided: determine if we can auto-select # No target provided: determine if we can auto-select
sessions = await cls._registry.list_sessions() # In remote-hosted mode, filter sessions by user_id
sessions = await cls._registry.list_sessions(user_id=user_id)
count = len(sessions) count = len(sessions)
if count == 0: if count == 0:
return None, count return None, count, explicit_required
if explicit_required:
return None, count, explicit_required
if count == 1: if count == 1:
return next(iter(sessions.keys())), count return next(iter(sessions.keys())), count, explicit_required
# Multiple sessions but no explicit target is ambiguous # Multiple sessions but no explicit target is ambiguous
return None, count return None, count, explicit_required
session_id, session_count = await _try_once() session_id, session_count, explicit_required = await _try_once()
if session_id is None and explicit_required and not target_hash and session_count > 0:
raise InstanceSelectionRequiredError()
deadline = time.monotonic() + max_wait_s deadline = time.monotonic() + max_wait_s
wait_started = None wait_started = None
@ -436,10 +526,10 @@ class PluginHub(WebSocketEndpoint):
# wait politely for a session to appear before surfacing an error. # wait politely for a session to appear before surfacing an error.
while session_id is None and time.monotonic() < deadline: while session_id is None and time.monotonic() < deadline:
if not target_hash and session_count > 1: if not target_hash and session_count > 1:
raise RuntimeError( raise InstanceSelectionRequiredError(
"Multiple Unity instances are connected. " InstanceSelectionRequiredError._MULTIPLE_INSTANCES)
"Call set_active_instance with Name@hash from mcpforunity://instances." if session_id is None and explicit_required and not target_hash and session_count > 0:
) raise InstanceSelectionRequiredError()
if wait_started is None: if wait_started is None:
wait_started = time.monotonic() wait_started = time.monotonic()
logger.debug( logger.debug(
@ -448,7 +538,7 @@ class PluginHub(WebSocketEndpoint):
max_wait_s, max_wait_s,
) )
await asyncio.sleep(sleep_seconds) await asyncio.sleep(sleep_seconds)
session_id, session_count = await _try_once() session_id, session_count, explicit_required = await _try_once()
if session_id is not None and wait_started is not None: if session_id is not None and wait_started is not None:
logger.debug( logger.debug(
@ -457,10 +547,11 @@ class PluginHub(WebSocketEndpoint):
unity_instance or "default", unity_instance or "default",
) )
if session_id is None and not target_hash and session_count > 1: if session_id is None and not target_hash and session_count > 1:
raise RuntimeError( raise InstanceSelectionRequiredError(
"Multiple Unity instances are connected. " InstanceSelectionRequiredError._MULTIPLE_INSTANCES)
"Call set_active_instance with Name@hash from mcpforunity://instances."
) if session_id is None and explicit_required and not target_hash and session_count > 0:
raise InstanceSelectionRequiredError()
if session_id is None: if session_id is None:
logger.warning( logger.warning(
@ -481,9 +572,18 @@ class PluginHub(WebSocketEndpoint):
unity_instance: str | None, unity_instance: str | None,
command_type: str, command_type: str,
params: dict[str, Any], params: dict[str, Any],
user_id: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Send a command to a Unity instance.
Args:
unity_instance: Target instance (Name@hash or hash)
command_type: Command type to execute
params: Command parameters
user_id: User ID for session isolation in remote-hosted mode
"""
try: try:
session_id = await cls._resolve_session_id(unity_instance) session_id = await cls._resolve_session_id(unity_instance, user_id=user_id)
except NoUnitySessionError: except NoUnitySessionError:
logger.debug( logger.debug(
"Unity session unavailable; returning retry: command=%s instance=%s", "Unity session unavailable; returning retry: command=%s instance=%s",

View File

@ -7,6 +7,7 @@ from datetime import datetime, timezone
import asyncio import asyncio
from core.config import config
from models.models import ToolDefinitionModel from models.models import ToolDefinitionModel
@ -22,7 +23,9 @@ class PluginSession:
connected_at: datetime connected_at: datetime
tools: dict[str, ToolDefinitionModel] = field(default_factory=dict) tools: dict[str, ToolDefinitionModel] = field(default_factory=dict)
project_id: str | None = None project_id: str | None = None
project_path: str | None = None # Full path to project root (for focus nudging) # Full path to project root (for focus nudging)
project_path: str | None = None
user_id: str | None = None # Associated user id (None for local mode)
class PluginRegistry: class PluginRegistry:
@ -31,11 +34,17 @@ class PluginRegistry:
The registry is optimised for quick lookup by either ``session_id`` or The registry is optimised for quick lookup by either ``session_id`` or
``project_hash`` (which is used as the canonical "instance id" across the ``project_hash`` (which is used as the canonical "instance id" across the
HTTP command routing stack). HTTP command routing stack).
In remote-hosted mode, sessions are scoped by (user_id, project_hash) composite key
to ensure session isolation between users.
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._sessions: dict[str, PluginSession] = {} self._sessions: dict[str, PluginSession] = {}
# In local mode: project_hash -> session_id
# In remote mode: (user_id, project_hash) -> session_id
self._hash_to_session: dict[str, str] = {} self._hash_to_session: dict[str, str] = {}
self._user_hash_to_session: dict[tuple[str, str], str] = {}
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
async def register( async def register(
@ -45,13 +54,16 @@ class PluginRegistry:
project_hash: str, project_hash: str,
unity_version: str, unity_version: str,
project_path: str | None = None, project_path: str | None = None,
user_id: str | None = None,
) -> PluginSession: ) -> PluginSession:
"""Register (or replace) a plugin session. """Register (or replace) a plugin session.
If an existing session already claims the same ``project_hash`` it will be If an existing session already claims the same ``project_hash`` (and ``user_id``
replaced, ensuring that reconnect scenarios always map to the latest in remote-hosted mode) it will be replaced, ensuring that reconnect scenarios
WebSocket connection. always map to the latest WebSocket connection.
""" """
if config.http_remote_hosted and not user_id:
raise ValueError("user_id is required in remote-hosted mode")
async with self._lock: async with self._lock:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
@ -63,15 +75,26 @@ class PluginRegistry:
registered_at=now, registered_at=now,
connected_at=now, connected_at=now,
project_path=project_path, project_path=project_path,
user_id=user_id,
) )
# Remove old mapping for this hash if it existed under a different session # Remove old mapping for this hash if it existed under a different session
if user_id:
# Remote-hosted mode: use composite key (user_id, project_hash)
composite_key = (user_id, project_hash)
previous_session_id = self._user_hash_to_session.get(
composite_key)
if previous_session_id and previous_session_id != session_id:
self._sessions.pop(previous_session_id, None)
self._user_hash_to_session[composite_key] = session_id
else:
# Local mode: use project_hash only
previous_session_id = self._hash_to_session.get(project_hash) previous_session_id = self._hash_to_session.get(project_hash)
if previous_session_id and previous_session_id != session_id: if previous_session_id and previous_session_id != session_id:
self._sessions.pop(previous_session_id, None) self._sessions.pop(previous_session_id, None)
self._hash_to_session[project_hash] = session_id
self._sessions[session_id] = session self._sessions[session_id] = session
self._hash_to_session[project_hash] = session_id
return session return session
async def touch(self, session_id: str) -> None: async def touch(self, session_id: str) -> None:
@ -87,12 +110,21 @@ class PluginRegistry:
async with self._lock: async with self._lock:
session = self._sessions.pop(session_id, None) session = self._sessions.pop(session_id, None)
if session and session.project_hash in self._hash_to_session: if session:
# Only delete the mapping if it still points at the removed session. # Clean up hash mappings
if session.project_hash in self._hash_to_session:
mapped = self._hash_to_session.get(session.project_hash) mapped = self._hash_to_session.get(session.project_hash)
if mapped == session_id: if mapped == session_id:
del self._hash_to_session[session.project_hash] del self._hash_to_session[session.project_hash]
# Clean up user-scoped mappings
if session.user_id:
composite_key = (session.user_id, session.project_hash)
if composite_key in self._user_hash_to_session:
mapped = self._user_hash_to_session.get(composite_key)
if mapped == session_id:
del self._user_hash_to_session[composite_key]
async def register_tools_for_session(self, session_id: str, tools: list[ToolDefinitionModel]) -> None: async def register_tools_for_session(self, session_id: str, tools: list[ToolDefinitionModel]) -> None:
"""Register tools for a specific session.""" """Register tools for a specific session."""
async with self._lock: async with self._lock:
@ -110,17 +142,41 @@ class PluginRegistry:
async with self._lock: async with self._lock:
return self._sessions.get(session_id) return self._sessions.get(session_id)
async def get_session_id_by_hash(self, project_hash: str) -> str | None: async def get_session_id_by_hash(self, project_hash: str, user_id: str | None = None) -> str | None:
"""Resolve a ``project_hash`` (Unity instance id) to a session id.""" """Resolve a ``project_hash`` (Unity instance id) to a session id."""
if user_id:
async with self._lock:
return self._user_hash_to_session.get((user_id, project_hash))
else:
async with self._lock: async with self._lock:
return self._hash_to_session.get(project_hash) return self._hash_to_session.get(project_hash)
async def list_sessions(self) -> dict[str, PluginSession]: async def list_sessions(self, user_id: str | None = None) -> dict[str, PluginSession]:
"""Return a shallow copy of all known sessions.""" """Return a shallow copy of sessions.
Args:
user_id: If provided, only return sessions for this user (remote-hosted mode).
If None, return all sessions (local mode only).
Raises:
ValueError: If ``user_id`` is None while running in remote-hosted mode.
This prevents accidentally leaking sessions across users.
"""
if user_id is None and config.http_remote_hosted:
raise ValueError(
"list_sessions requires user_id in remote-hosted mode"
)
async with self._lock: async with self._lock:
if user_id is None:
return dict(self._sessions) return dict(self._sessions)
else:
return {
sid: session
for sid, session in self._sessions.items()
if session.user_id == user_id
}
__all__ = ["PluginRegistry", "PluginSession"] __all__ = ["PluginRegistry", "PluginSession"]

View File

@ -9,6 +9,7 @@ import logging
from fastmcp.server.middleware import Middleware, MiddlewareContext from fastmcp.server.middleware import Middleware, MiddlewareContext
from core.config import config
from transport.plugin_hub import PluginHub from transport.plugin_hub import PluginHub
logger = logging.getLogger("mcp-for-unity-server") logger = logging.getLogger("mcp-for-unity-server")
@ -32,7 +33,12 @@ def get_unity_instance_middleware() -> 'UnityInstanceMiddleware':
def set_unity_instance_middleware(middleware: 'UnityInstanceMiddleware') -> None: def set_unity_instance_middleware(middleware: 'UnityInstanceMiddleware') -> None:
"""Set the global Unity instance middleware (called during server initialization).""" """Replace the global middleware instance.
This is a test seam: production code uses ``get_unity_instance_middleware()``
which lazy-initialises the singleton. Tests call this function to inject a
mock or pre-configured middleware before exercising tool/resource code.
"""
global _unity_instance_middleware global _unity_instance_middleware
_unity_instance_middleware = middleware _unity_instance_middleware = middleware
@ -55,13 +61,18 @@ class UnityInstanceMiddleware(Middleware):
Derive a stable key for the calling session. Derive a stable key for the calling session.
Prioritizes client_id for stability. Prioritizes client_id for stability.
If client_id is missing, falls back to 'global' (assuming single-user local mode), In remote-hosted mode, falls back to user_id for session isolation.
ignoring session_id which can be unstable in some transports/clients. Otherwise falls back to 'global' (assuming single-user local mode).
""" """
client_id = getattr(ctx, "client_id", None) client_id = getattr(ctx, "client_id", None)
if isinstance(client_id, str) and client_id: if isinstance(client_id, str) and client_id:
return client_id return client_id
# In remote-hosted mode, use user_id so different users get isolated instance selections
user_id = ctx.get_state("user_id")
if isinstance(user_id, str) and user_id:
return f"user:{user_id}"
# Fallback to global for local dev stability # Fallback to global for local dev stability
return "global" return "global"
@ -92,10 +103,10 @@ class UnityInstanceMiddleware(Middleware):
to stick for subsequent tool/resource calls in the same session. to stick for subsequent tool/resource calls in the same session.
""" """
try: try:
# Import here to avoid circular dependencies / optional transport modules. transport = (config.transport_mode or "stdio").lower()
from transport.unity_transport import _current_transport # This implicit behavior works well for solo-users, but is dangerous for multi-user setups
if transport == "http" and config.http_remote_hosted:
transport = _current_transport() return None
if PluginHub.is_configured(): if PluginHub.is_configured():
try: try:
sessions_data = await PluginHub.get_sessions() sessions_data = await PluginHub.get_sessions()
@ -172,10 +183,27 @@ class UnityInstanceMiddleware(Middleware):
return None return None
async def _resolve_user_id(self) -> str | None:
"""Extract user_id from the current HTTP request's API key."""
if not config.http_remote_hosted:
return None
# Lazy import to avoid circular dependencies (same pattern as _maybe_autoselect_instance).
from transport.unity_transport import _resolve_user_id_from_request
return await _resolve_user_id_from_request()
async def _inject_unity_instance(self, context: MiddlewareContext) -> None: async def _inject_unity_instance(self, context: MiddlewareContext) -> None:
"""Inject active Unity instance into context if available.""" """Inject active Unity instance and user_id into context if available."""
ctx = context.fastmcp_context ctx = context.fastmcp_context
# Resolve user_id from the HTTP request's API key header
user_id = await self._resolve_user_id()
if config.http_remote_hosted and user_id is None:
raise RuntimeError(
"API key authentication required. Provide a valid X-API-Key header."
)
if user_id:
ctx.set_state("user_id", user_id)
active_instance = self.get_active_instance(ctx) active_instance = self.get_active_instance(ctx)
if not active_instance: if not active_instance:
active_instance = await self._maybe_autoselect_instance(ctx) active_instance = await self._maybe_autoselect_instance(ctx)
@ -193,7 +221,8 @@ class UnityInstanceMiddleware(Middleware):
# resolving session_id might fail if the plugin disconnected # resolving session_id might fail if the plugin disconnected
# We only need session_id for HTTP transport routing. # We only need session_id for HTTP transport routing.
# For stdio, we just need the instance ID. # For stdio, we just need the instance ID.
session_id = await PluginHub._resolve_session_id(active_instance) # Pass user_id for remote-hosted mode session isolation
session_id = await PluginHub._resolve_session_id(active_instance, user_id=user_id)
except (ConnectionError, ValueError, KeyError, TimeoutError) as exc: except (ConnectionError, ValueError, KeyError, TimeoutError) as exc:
# If resolution fails, it means the Unity instance is not reachable via HTTP/WS. # If resolution fails, it means the Unity instance is not reachable via HTTP/WS.
# If we are in stdio mode, this might still be fine if the user is just setting state? # If we are in stdio mode, this might still be fine if the user is just setting state?

View File

@ -1,34 +1,49 @@
"""Transport helpers for routing commands to Unity.""" """Transport helpers for routing commands to Unity."""
from __future__ import annotations from __future__ import annotations
import asyncio import logging
import inspect
import os
from typing import Awaitable, Callable, TypeVar from typing import Awaitable, Callable, TypeVar
from fastmcp import Context
from transport.plugin_hub import PluginHub from transport.plugin_hub import PluginHub
from core.config import config
from core.constants import API_KEY_HEADER
from services.api_key_service import ApiKeyService
from models.models import MCPResponse from models.models import MCPResponse
from models.unity_response import normalize_unity_response from models.unity_response import normalize_unity_response
from services.tools import get_unity_instance_from_context
T = TypeVar("T") T = TypeVar("T")
logger = logging.getLogger("mcp-for-unity-server")
def _is_http_transport() -> bool: def _is_http_transport() -> bool:
return os.environ.get("UNITY_MCP_TRANSPORT", "stdio").lower() == "http" return config.transport_mode.lower() == "http"
def _current_transport() -> str: async def _resolve_user_id_from_request() -> str | None:
"""Expose the active transport mode as a simple string identifier.""" """Extract user_id from the current HTTP request's API key header."""
return "http" if _is_http_transport() else "stdio" if not config.http_remote_hosted:
return None
if not ApiKeyService.is_initialized():
return None
try:
from fastmcp.server.dependencies import get_http_headers
headers = get_http_headers(include_all=True)
api_key = headers.get(API_KEY_HEADER.lower())
if not api_key:
return None
service = ApiKeyService.get_instance()
result = await service.validate(api_key)
return result.user_id if result.valid else None
except Exception as e:
logger.debug("Failed to resolve user_id from HTTP request: %s", e)
return None
async def send_with_unity_instance( async def send_with_unity_instance(
send_fn: Callable[..., Awaitable[T]], send_fn: Callable[..., Awaitable[T]],
unity_instance: str | None, unity_instance: str | None,
*args, *args,
user_id: str | None = None,
**kwargs, **kwargs,
) -> T: ) -> T:
if _is_http_transport(): if _is_http_transport():
@ -41,11 +56,27 @@ async def send_with_unity_instance(
if not isinstance(params, dict): if not isinstance(params, dict):
raise TypeError( raise TypeError(
"Command parameters must be a dict for HTTP transport") "Command parameters must be a dict for HTTP transport")
# Auto-resolve user_id from HTTP request API key (remote-hosted mode)
if user_id is None:
user_id = await _resolve_user_id_from_request()
# Auth check
if config.http_remote_hosted and not user_id:
return normalize_unity_response(
MCPResponse(
success=False,
error="auth_required",
message="API key required",
).model_dump()
)
try: try:
raw = await PluginHub.send_command_for_instance( raw = await PluginHub.send_command_for_instance(
unity_instance, unity_instance,
command_type, command_type,
params, params,
user_id=user_id,
) )
return normalize_unity_response(raw) return normalize_unity_response(raw)
except Exception as exc: except Exception as exc:

View File

@ -1,5 +1,6 @@
"""Pytest configuration for unity-mcp tests.""" """Pytest configuration for unity-mcp tests."""
import logging import logging
import os
import sys import sys
from pathlib import Path from pathlib import Path
import pytest import pytest
@ -58,3 +59,29 @@ def pytest_collection_modifyitems(session, config, items): # noqa: ARG001
# Reorder: characterization/unit tests first, then integration tests # Reorder: characterization/unit tests first, then integration tests
items[:] = other_tests + integration_tests items[:] = other_tests + integration_tests
@pytest.fixture(autouse=True)
def restore_global_config():
"""Restore global config/env mutations between tests."""
from core.config import config as global_config
prior_env = os.environ.get("UNITY_MCP_TRANSPORT")
prior = {
"transport_mode": global_config.transport_mode,
"http_remote_hosted": global_config.http_remote_hosted,
"api_key_validation_url": global_config.api_key_validation_url,
"api_key_login_url": global_config.api_key_login_url,
"api_key_cache_ttl": global_config.api_key_cache_ttl,
"api_key_service_token_header": global_config.api_key_service_token_header,
"api_key_service_token": global_config.api_key_service_token,
}
yield
if prior_env is None:
os.environ.pop("UNITY_MCP_TRANSPORT", None)
else:
os.environ["UNITY_MCP_TRANSPORT"] = prior_env
for key, value in prior.items():
setattr(global_config, key, value)

View File

@ -0,0 +1,456 @@
"""Tests for ApiKeyService: validation, caching, retries, and singleton lifecycle."""
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from services.api_key_service import ApiKeyService, ValidationResult
@pytest.fixture(autouse=True)
def _reset_singleton():
"""Reset the ApiKeyService singleton between tests."""
ApiKeyService._instance = None
yield
ApiKeyService._instance = None
def _make_service(
validation_url="https://auth.example.com/validate",
cache_ttl=300.0,
service_token_header=None,
service_token=None,
):
return ApiKeyService(
validation_url=validation_url,
cache_ttl=cache_ttl,
service_token_header=service_token_header,
service_token=service_token,
)
def _mock_response(status_code=200, json_data=None):
resp = MagicMock(spec=httpx.Response)
resp.status_code = status_code
resp.json.return_value = json_data or {}
return resp
# ---------------------------------------------------------------------------
# Singleton lifecycle
# ---------------------------------------------------------------------------
class TestSingletonLifecycle:
def test_get_instance_before_init_raises(self):
with pytest.raises(RuntimeError, match="not initialized"):
ApiKeyService.get_instance()
def test_is_initialized_false_before_init(self):
assert ApiKeyService.is_initialized() is False
def test_is_initialized_true_after_init(self):
_make_service()
assert ApiKeyService.is_initialized() is True
def test_get_instance_returns_service(self):
svc = _make_service()
assert ApiKeyService.get_instance() is svc
# ---------------------------------------------------------------------------
# Basic validation
# ---------------------------------------------------------------------------
class TestBasicValidation:
@pytest.mark.asyncio
async def test_valid_key(self):
svc = _make_service()
mock_resp = _mock_response(
200, {"valid": True, "user_id": "user-1", "metadata": {"plan": "pro"}})
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = AsyncMock(return_value=mock_resp)
MockClient.return_value = instance
result = await svc.validate("test-valid-key-12345678")
assert result.valid is True
assert result.user_id == "user-1"
assert result.metadata == {"plan": "pro"}
@pytest.mark.asyncio
async def test_invalid_key_200_body(self):
svc = _make_service()
mock_resp = _mock_response(
200, {"valid": False, "error": "Key revoked"})
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = AsyncMock(return_value=mock_resp)
MockClient.return_value = instance
result = await svc.validate("test-invalid-key-1234")
assert result.valid is False
assert result.error == "Key revoked"
@pytest.mark.asyncio
async def test_invalid_key_401_status(self):
svc = _make_service()
mock_resp = _mock_response(401)
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = AsyncMock(return_value=mock_resp)
MockClient.return_value = instance
result = await svc.validate("test-bad-key-12345678")
assert result.valid is False
assert "Invalid API key" in result.error
@pytest.mark.asyncio
async def test_empty_key_fast_path(self):
svc = _make_service()
with patch("httpx.AsyncClient") as MockClient:
result = await svc.validate("")
assert result.valid is False
assert "required" in result.error.lower()
# No HTTP call should have been made
MockClient.assert_not_called()
# ---------------------------------------------------------------------------
# Caching
# ---------------------------------------------------------------------------
class TestCaching:
@pytest.mark.asyncio
async def test_cache_hit_valid_key(self):
svc = _make_service(cache_ttl=300.0)
mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"})
call_count = 0
async def counting_post(*args, **kwargs):
nonlocal call_count
call_count += 1
return mock_resp
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = counting_post
MockClient.return_value = instance
r1 = await svc.validate("test-cached-valid-key1")
r2 = await svc.validate("test-cached-valid-key1")
assert r1.valid is True
assert r2.valid is True
assert r2.user_id == "u1"
assert call_count == 1 # Only one HTTP call
@pytest.mark.asyncio
async def test_cache_hit_invalid_key(self):
svc = _make_service(cache_ttl=300.0)
mock_resp = _mock_response(200, {"valid": False, "error": "bad"})
call_count = 0
async def counting_post(*args, **kwargs):
nonlocal call_count
call_count += 1
return mock_resp
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = counting_post
MockClient.return_value = instance
r1 = await svc.validate("test-cached-bad-key12")
r2 = await svc.validate("test-cached-bad-key12")
assert r1.valid is False
assert r2.valid is False
assert call_count == 1
@pytest.mark.asyncio
async def test_cache_expiry(self):
svc = _make_service(cache_ttl=1.0) # 1 second TTL
mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"})
call_count = 0
async def counting_post(*args, **kwargs):
nonlocal call_count
call_count += 1
return mock_resp
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = counting_post
MockClient.return_value = instance
await svc.validate("test-expiry-key-12345")
assert call_count == 1
# Manually expire the cache entry by manipulating the stored tuple
async with svc._cache_lock:
key = "test-expiry-key-12345"
valid, user_id, metadata, _expires = svc._cache[key]
svc._cache[key] = (valid, user_id, metadata, time.time() - 1)
await svc.validate("test-expiry-key-12345")
assert call_count == 2 # Had to re-validate
@pytest.mark.asyncio
async def test_invalidate_cache(self):
svc = _make_service(cache_ttl=300.0)
mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"})
call_count = 0
async def counting_post(*args, **kwargs):
nonlocal call_count
call_count += 1
return mock_resp
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = counting_post
MockClient.return_value = instance
await svc.validate("test-invalidate-key12")
assert call_count == 1
await svc.invalidate_cache("test-invalidate-key12")
await svc.validate("test-invalidate-key12")
assert call_count == 2
@pytest.mark.asyncio
async def test_clear_cache(self):
svc = _make_service(cache_ttl=300.0)
mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"})
call_count = 0
async def counting_post(*args, **kwargs):
nonlocal call_count
call_count += 1
return mock_resp
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = counting_post
MockClient.return_value = instance
await svc.validate("test-clear-key1-12345")
await svc.validate("test-clear-key2-12345")
assert call_count == 2
await svc.clear_cache()
await svc.validate("test-clear-key1-12345")
await svc.validate("test-clear-key2-12345")
assert call_count == 4 # Both had to re-validate
# ---------------------------------------------------------------------------
# Transient failures & retries
# ---------------------------------------------------------------------------
class TestTransientFailures:
@pytest.mark.asyncio
async def test_5xx_not_cached(self):
svc = _make_service(cache_ttl=300.0)
mock_500 = _mock_response(500)
mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"})
responses = [mock_500, mock_500, mock_ok] # Extra for retry
call_idx = 0
async def sequential_post(*args, **kwargs):
nonlocal call_idx
resp = responses[min(call_idx, len(responses) - 1)]
call_idx += 1
return resp
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = sequential_post
MockClient.return_value = instance
# First call: 500 -> not cached
r1 = await svc.validate("test-5xx-test-key1234")
assert r1.valid is False
assert r1.cacheable is False
# Second call should hit HTTP again (not cached)
r2 = await svc.validate("test-5xx-test-key1234")
# Second call also gets 500 from our mock sequence
assert r2.valid is False
@pytest.mark.asyncio
async def test_timeout_then_retry_succeeds(self):
svc = _make_service()
mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"})
attempt = 0
async def timeout_then_ok(*args, **kwargs):
nonlocal attempt
attempt += 1
if attempt == 1:
raise httpx.TimeoutException("timed out")
return mock_ok
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = timeout_then_ok
MockClient.return_value = instance
result = await svc.validate("test-timeout-retry-ok")
assert result.valid is True
assert result.user_id == "u1"
assert attempt == 2
@pytest.mark.asyncio
async def test_timeout_exhausts_retries(self):
svc = _make_service()
async def always_timeout(*args, **kwargs):
raise httpx.TimeoutException("timed out")
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = always_timeout
MockClient.return_value = instance
result = await svc.validate("test-timeout-exhaust1")
assert result.valid is False
assert "timeout" in result.error.lower()
assert result.cacheable is False
@pytest.mark.asyncio
async def test_request_error_then_retry_succeeds(self):
svc = _make_service()
mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"})
attempt = 0
async def error_then_ok(*args, **kwargs):
nonlocal attempt
attempt += 1
if attempt == 1:
raise httpx.ConnectError("connection refused")
return mock_ok
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = error_then_ok
MockClient.return_value = instance
result = await svc.validate("test-reqerr-retry-ok1")
assert result.valid is True
assert attempt == 2
@pytest.mark.asyncio
async def test_request_error_exhausts_retries(self):
svc = _make_service()
async def always_error(*args, **kwargs):
raise httpx.ConnectError("connection refused")
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = always_error
MockClient.return_value = instance
result = await svc.validate("test-reqerr-exhaust1")
assert result.valid is False
assert "unavailable" in result.error.lower()
assert result.cacheable is False
@pytest.mark.asyncio
async def test_unexpected_exception(self):
svc = _make_service()
async def unexpected(*args, **kwargs):
raise ValueError("something unexpected")
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = unexpected
MockClient.return_value = instance
result = await svc.validate("test-unexpected-err12")
assert result.valid is False
assert result.cacheable is False
# ---------------------------------------------------------------------------
# Service token
# ---------------------------------------------------------------------------
class TestServiceToken:
@pytest.mark.asyncio
async def test_service_token_sent_in_headers(self):
svc = _make_service(
service_token_header="X-Service-Token",
service_token="test-svc-token-123",
)
mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"})
captured_headers = {}
async def capture_post(url, *, json=None, headers=None):
captured_headers.update(headers or {})
return mock_resp
with patch("httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.post = capture_post
MockClient.return_value = instance
await svc.validate("test-svctoken-key1234")
assert captured_headers.get("X-Service-Token") == "test-svc-token-123"
assert captured_headers.get("Content-Type") == "application/json"

View File

@ -0,0 +1,114 @@
"""Tests for auth configuration validation and startup routes."""
import json
import sys
from unittest.mock import MagicMock
import pytest
from core.config import config
from starlette.requests import Request
from starlette.responses import JSONResponse
@pytest.fixture(autouse=True)
def _restore_config(monkeypatch):
"""Prevent main() side effects on the global config from leaking to other tests."""
monkeypatch.setattr(config, "http_remote_hosted", config.http_remote_hosted)
monkeypatch.setattr(config, "api_key_validation_url", config.api_key_validation_url)
monkeypatch.setattr(config, "api_key_login_url", config.api_key_login_url)
monkeypatch.setattr(config, "api_key_cache_ttl", config.api_key_cache_ttl)
monkeypatch.setattr(config, "api_key_service_token_header", config.api_key_service_token_header)
monkeypatch.setattr(config, "api_key_service_token", config.api_key_service_token)
class TestStartupConfigValidation:
def test_remote_hosted_flag_without_validation_url_exits(self, monkeypatch):
"""--http-remote-hosted without --api-key-validation-url should SystemExit(1)."""
monkeypatch.setattr(
sys,
"argv",
[
"main",
"--transport", "http",
"--http-remote-hosted",
# Deliberately omit --api-key-validation-url
],
)
monkeypatch.delenv("UNITY_MCP_API_KEY_VALIDATION_URL", raising=False)
monkeypatch.delenv("UNITY_MCP_HTTP_REMOTE_HOSTED", raising=False)
from main import main
with pytest.raises(SystemExit) as exc_info:
main()
assert exc_info.value.code == 1
def test_remote_hosted_env_var_without_validation_url_exits(self, monkeypatch):
"""UNITY_MCP_HTTP_REMOTE_HOSTED=true without validation URL should SystemExit(1)."""
monkeypatch.setattr(
sys,
"argv",
[
"main",
"--transport", "http",
# No --http-remote-hosted flag
],
)
monkeypatch.setenv("UNITY_MCP_HTTP_REMOTE_HOSTED", "true")
monkeypatch.delenv("UNITY_MCP_API_KEY_VALIDATION_URL", raising=False)
from main import main
with pytest.raises(SystemExit) as exc_info:
main()
assert exc_info.value.code == 1
class TestLoginUrlEndpoint:
"""Test the /api/auth/login-url route handler logic.
These tests replicate the handler inline to avoid full MCP server construction.
The logic mirrors main.py's auth_login_url route exactly.
"""
@staticmethod
async def _auth_login_url(_request):
"""Replicate the route handler from main.py."""
if not config.api_key_login_url:
return JSONResponse(
{
"success": False,
"error": "API key management not configured. Contact your server administrator.",
},
status_code=404,
)
return JSONResponse({
"success": True,
"login_url": config.api_key_login_url,
})
@pytest.mark.asyncio
async def test_login_url_returns_url_when_configured(self, monkeypatch):
monkeypatch.setattr(config, "api_key_login_url",
"https://app.example.com/keys")
response = await self._auth_login_url(MagicMock(spec=Request))
assert response.status_code == 200
body = json.loads(response.body.decode())
assert body["success"] is True
assert body["login_url"] == "https://app.example.com/keys"
@pytest.mark.asyncio
async def test_login_url_returns_404_when_not_configured(self, monkeypatch):
monkeypatch.setattr(config, "api_key_login_url", None)
response = await self._auth_login_url(MagicMock(spec=Request))
assert response.status_code == 404
body = json.loads(response.body.decode())
assert body["success"] is False
assert "not configured" in body["error"]

View File

@ -27,7 +27,7 @@ async def test_plugin_hub_waits_for_reconnection_during_reload():
# Third call: session appears (plugin reconnected) # Third call: session appears (plugin reconnected)
call_count = [0] call_count = [0]
async def mock_list_sessions(): async def mock_list_sessions(**kwargs):
call_count[0] += 1 call_count[0] += 1
if call_count[0] <= 2: if call_count[0] <= 2:
# Plugin not yet reconnected # Plugin not yet reconnected
@ -77,7 +77,7 @@ async def test_plugin_hub_fails_after_timeout():
# Create a mock registry that never returns sessions # Create a mock registry that never returns sessions
mock_registry = AsyncMock(spec=PluginRegistry) mock_registry = AsyncMock(spec=PluginRegistry)
async def mock_list_sessions(): async def mock_list_sessions(**kwargs):
return {} # Never returns sessions return {} # Never returns sessions
mock_registry.list_sessions = mock_list_sessions mock_registry.list_sessions = mock_list_sessions
@ -161,7 +161,7 @@ async def test_read_console_during_simulated_reload(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_plugin_hub_respects_unity_instance_preference(): async def test_plugin_hub_respects_unity_instance_preference():
"""Test that _resolve_session_id prefers a specific Unity instance if requested.""" """Test that _resolve_session_id prefers a specific Unity instance if requested."""
from transport.plugin_hub import PluginHub from transport.plugin_hub import PluginHub, InstanceSelectionRequiredError
from transport.plugin_registry import PluginRegistry, PluginSession from transport.plugin_registry import PluginRegistry, PluginSession
# Create a mock registry with two sessions # Create a mock registry with two sessions
@ -185,19 +185,19 @@ async def test_plugin_hub_respects_unity_instance_preference():
connected_at=now connected_at=now
) )
async def mock_list_sessions(): async def mock_list_sessions(**kwargs):
return { return {
"session-1": session1, "session-1": session1,
"session-2": session2 "session-2": session2
} }
async def mock_get_session_by_hash(project_hash): async def mock_get_session_id_by_hash(project_hash, user_id=None):
if project_hash == "hash2": if project_hash == "hash2":
return "session-2" return "session-2"
return None return None
mock_registry.list_sessions = mock_list_sessions mock_registry.list_sessions = mock_list_sessions
mock_registry.get_session_id_by_hash = mock_get_session_by_hash mock_registry.get_session_id_by_hash = mock_get_session_id_by_hash
# Configure PluginHub with our mock while preserving the original state # Configure PluginHub with our mock while preserving the original state
original_registry = PluginHub._registry original_registry = PluginHub._registry
@ -213,9 +213,8 @@ async def test_plugin_hub_respects_unity_instance_preference():
assert session_id == "session-2" assert session_id == "session-2"
# Request default (no specific instance) # Request default (no specific instance)
with pytest.raises(RuntimeError) as exc: with pytest.raises(InstanceSelectionRequiredError, match="Multiple Unity instances"):
await PluginHub._resolve_session_id(unity_instance=None) await PluginHub._resolve_session_id(unity_instance=None)
assert "Multiple Unity instances are connected" in str(exc.value)
finally: finally:
# Clean up: restore original PluginHub state # Clean up: restore original PluginHub state

View File

@ -4,6 +4,7 @@ import types
from types import SimpleNamespace from types import SimpleNamespace
from .test_helpers import DummyContext from .test_helpers import DummyContext
from core.config import config
class DummyMiddlewareContext: class DummyMiddlewareContext:
@ -25,15 +26,12 @@ def test_auto_selects_single_instance_via_pluginhub(monkeypatch):
plugin_hub.PluginHub = PluginHub plugin_hub.PluginHub = PluginHub
monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub) monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub)
unity_transport = types.ModuleType("transport.unity_transport")
unity_transport._current_transport = lambda: "http"
monkeypatch.setitem(sys.modules, "transport.unity_transport", unity_transport)
monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False) monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False)
from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub
assert ImportedPluginHub is plugin_hub.PluginHub assert ImportedPluginHub is plugin_hub.PluginHub
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http") monkeypatch.setattr(config, "transport_mode", "http")
middleware = UnityInstanceMiddleware() middleware = UnityInstanceMiddleware()
ctx = DummyContext() ctx = DummyContext()
@ -74,15 +72,12 @@ def test_auto_selects_single_instance_via_stdio(monkeypatch):
plugin_hub.PluginHub = PluginHub plugin_hub.PluginHub = PluginHub
monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub) monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub)
unity_transport = types.ModuleType("transport.unity_transport")
unity_transport._current_transport = lambda: "stdio"
monkeypatch.setitem(sys.modules, "transport.unity_transport", unity_transport)
monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False) monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False)
from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub
assert ImportedPluginHub is plugin_hub.PluginHub assert ImportedPluginHub is plugin_hub.PluginHub
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "stdio") monkeypatch.setattr(config, "transport_mode", "stdio")
middleware = UnityInstanceMiddleware() middleware = UnityInstanceMiddleware()
ctx = DummyContext() ctx = DummyContext()
@ -118,9 +113,6 @@ def test_auto_select_handles_stdio_errors(monkeypatch):
plugin_hub.PluginHub = PluginHub plugin_hub.PluginHub = PluginHub
monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub) monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub)
unity_transport = types.ModuleType("transport.unity_transport")
unity_transport._current_transport = lambda: "stdio"
monkeypatch.setitem(sys.modules, "transport.unity_transport", unity_transport)
monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False) monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False)
from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub

View File

@ -14,6 +14,7 @@ import pytest
from unittest.mock import AsyncMock, Mock, MagicMock, patch from unittest.mock import AsyncMock, Mock, MagicMock, patch
from fastmcp import Context from fastmcp import Context
from core.config import config
from transport.unity_instance_middleware import UnityInstanceMiddleware from transport.unity_instance_middleware import UnityInstanceMiddleware
from services.tools import get_unity_instance_from_context from services.tools import get_unity_instance_from_context
from services.tools.set_active_instance import set_active_instance as set_active_instance_tool from services.tools.set_active_instance import set_active_instance as set_active_instance_tool
@ -28,6 +29,7 @@ class TestInstanceRoutingBasics:
middleware = UnityInstanceMiddleware() middleware = UnityInstanceMiddleware()
ctx = Mock(spec=Context) ctx = Mock(spec=Context)
ctx.session_id = "test-session-1" ctx.session_id = "test-session-1"
ctx.client_id = "test-client-1"
# Set active instance # Set active instance
middleware.set_active_instance(ctx, "TestProject@abc123") middleware.set_active_instance(ctx, "TestProject@abc123")
@ -73,6 +75,7 @@ class TestInstanceRoutingBasics:
ctx = Mock(spec=Context) ctx = Mock(spec=Context)
ctx.session_id = None ctx.session_id = None
ctx.client_id = None ctx.client_id = None
ctx.get_state = Mock(return_value=None)
middleware.set_active_instance(ctx, "Project@global") middleware.set_active_instance(ctx, "Project@global")
assert middleware.get_active_instance(ctx) == "Project@global" assert middleware.get_active_instance(ctx) == "Project@global"
@ -170,7 +173,7 @@ class TestInstanceRoutingHTTP:
v: state_storage.__setitem__(k, v)) v: state_storage.__setitem__(k, v))
ctx.get_state = Mock(side_effect=lambda k: state_storage.get(k)) ctx.get_state = Mock(side_effect=lambda k: state_storage.get(k))
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http") monkeypatch.setattr(config, "transport_mode", "http")
fake_sessions = SessionList( fake_sessions = SessionList(
sessions={ sessions={
"sess-1": SessionDetails( "sess-1": SessionDetails(
@ -206,7 +209,7 @@ class TestInstanceRoutingHTTP:
v: state_storage.__setitem__(k, v)) v: state_storage.__setitem__(k, v))
ctx.get_state = Mock(side_effect=lambda k: state_storage.get(k)) ctx.get_state = Mock(side_effect=lambda k: state_storage.get(k))
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http") monkeypatch.setattr(config, "transport_mode", "http")
fake_sessions = SessionList( fake_sessions = SessionList(
sessions={ sessions={
"sess-99": SessionDetails( "sess-99": SessionDetails(
@ -238,7 +241,7 @@ class TestInstanceRoutingHTTP:
ctx = Mock(spec=Context) ctx = Mock(spec=Context)
ctx.session_id = "http-session-3" ctx.session_id = "http-session-3"
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http") monkeypatch.setattr(config, "transport_mode", "http")
fake_sessions = SessionList(sessions={}) fake_sessions = SessionList(sessions={})
monkeypatch.setattr( monkeypatch.setattr(
"services.tools.set_active_instance.PluginHub.get_sessions", "services.tools.set_active_instance.PluginHub.get_sessions",
@ -261,7 +264,7 @@ class TestInstanceRoutingHTTP:
ctx = Mock(spec=Context) ctx = Mock(spec=Context)
ctx.session_id = "http-session-4" ctx.session_id = "http-session-4"
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http") monkeypatch.setattr(config, "transport_mode", "http")
fake_sessions = SessionList( fake_sessions = SessionList(
sessions={ sessions={
"sess-a": SessionDetails(project="ProjA", hash="abc12345", unity_version="2022", connected_at="now"), "sess-a": SessionDetails(project="ProjA", hash="abc12345", unity_version="2022", connected_at="now"),

View File

@ -0,0 +1,172 @@
"""Tests for UnityInstanceMiddleware auth enforcement in remote-hosted mode."""
import asyncio
import sys
from unittest.mock import AsyncMock, Mock, patch
import pytest
from core.config import config
from tests.integration.test_helpers import DummyContext
class TestMiddlewareAuthEnforcement:
@pytest.mark.asyncio
async def test_remote_hosted_requires_user_id(self, monkeypatch):
"""_inject_unity_instance should raise RuntimeError when remote-hosted and no user_id."""
monkeypatch.setattr(config, "http_remote_hosted", True)
from transport.unity_instance_middleware import UnityInstanceMiddleware
middleware = UnityInstanceMiddleware()
# Mock _resolve_user_id to return None (no API key / failed validation)
monkeypatch.setattr(middleware, "_resolve_user_id",
AsyncMock(return_value=None))
ctx = DummyContext()
middleware_ctx = Mock()
middleware_ctx.fastmcp_context = ctx
with pytest.raises(RuntimeError, match="API key authentication required"):
await middleware._inject_unity_instance(middleware_ctx)
@pytest.mark.asyncio
async def test_sets_user_id_in_context_state(self, monkeypatch):
"""_inject_unity_instance should set user_id in ctx state when resolved."""
monkeypatch.setattr(config, "http_remote_hosted", True)
from transport.unity_instance_middleware import UnityInstanceMiddleware
middleware = UnityInstanceMiddleware()
monkeypatch.setattr(middleware, "_resolve_user_id",
AsyncMock(return_value="user-55"))
# We need PluginHub to be configured for the session resolution path
# But we don't need it to actually find a session for this test
from transport.plugin_hub import PluginHub
from transport.plugin_registry import PluginRegistry
registry = PluginRegistry()
loop = asyncio.get_running_loop()
PluginHub.configure(registry, loop)
ctx = DummyContext()
ctx.client_id = "client-1"
middleware_ctx = Mock()
middleware_ctx.fastmcp_context = ctx
# Set an active instance so the middleware doesn't try to auto-select
middleware.set_active_instance(ctx, "Proj@hash1")
# Register a matching session so resolution doesn't fail
await registry.register("s1", "Proj", "hash1", "2022", user_id="user-55")
await middleware._inject_unity_instance(middleware_ctx)
assert ctx.get_state("user_id") == "user-55"
class TestMiddlewareSessionKey:
def test_get_session_key_uses_user_id_fallback(self):
"""When no client_id, middleware should use user:$user_id as session key."""
from transport.unity_instance_middleware import UnityInstanceMiddleware
middleware = UnityInstanceMiddleware()
ctx = DummyContext()
# Simulate no client_id attribute
if hasattr(ctx, "client_id"):
delattr(ctx, "client_id")
ctx.set_state("user_id", "user-77")
key = middleware.get_session_key(ctx)
assert key == "user:user-77"
def test_get_session_key_prefers_client_id(self):
"""client_id should take precedence over user_id."""
from transport.unity_instance_middleware import UnityInstanceMiddleware
middleware = UnityInstanceMiddleware()
ctx = DummyContext()
ctx.client_id = "client-abc"
ctx.set_state("user_id", "user-77")
key = middleware.get_session_key(ctx)
assert key == "client-abc"
class TestAutoSelectDisabledRemoteHosted:
@pytest.mark.asyncio
async def test_auto_select_returns_none_in_remote_hosted(self, monkeypatch):
"""_maybe_autoselect_instance should return None in remote-hosted mode even with one session."""
monkeypatch.setattr(config, "http_remote_hosted", True)
monkeypatch.setattr(config, "transport_mode", "http")
# Re-import middleware to pick up the stubbed transport module
monkeypatch.delitem(
sys.modules, "transport.unity_instance_middleware", raising=False)
from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as HubRef
# Configure PluginHub with one session so auto-select has something to find
from transport.plugin_registry import PluginRegistry
registry = PluginRegistry()
await registry.register("s1", "Proj", "h1", "2022", user_id="userA")
loop = asyncio.get_running_loop()
HubRef.configure(registry, loop)
middleware = UnityInstanceMiddleware()
ctx = DummyContext()
ctx.client_id = "client-1"
result = await middleware._maybe_autoselect_instance(ctx)
# Remote-hosted mode should NOT auto-select (early return at the transport check)
assert result is None
class TestHttpAuthBehavior:
@pytest.mark.asyncio
async def test_http_local_does_not_require_user_id(self, monkeypatch):
"""HTTP local mode should allow requests without user_id."""
monkeypatch.setattr(config, "http_remote_hosted", False)
monkeypatch.setattr(config, "transport_mode", "http")
from transport import unity_transport
async def fake_send_command_for_instance(*_args, **_kwargs):
return {"success": True, "data": {"ok": True}}
monkeypatch.setattr(
unity_transport.PluginHub,
"send_command_for_instance",
fake_send_command_for_instance,
)
async def _unused_send_fn(*_args, **_kwargs):
raise AssertionError("send_fn should not be used in HTTP mode")
result = await unity_transport.send_with_unity_instance(
_unused_send_fn, None, "ping", {}
)
assert result["success"] is True
assert result["data"] == {"ok": True}
@pytest.mark.asyncio
async def test_http_remote_requires_user_id(self, monkeypatch):
"""HTTP remote-hosted mode should reject requests without user_id."""
monkeypatch.setattr(config, "http_remote_hosted", True)
monkeypatch.setattr(config, "transport_mode", "http")
from transport import unity_transport
async def _unused_send_fn(*_args, **_kwargs):
raise AssertionError("send_fn should not be used in HTTP mode")
result = await unity_transport.send_with_unity_instance(
_unused_send_fn, None, "ping", {}
)
assert result["success"] is False
assert result["error"] == "auth_required"

View File

@ -0,0 +1,176 @@
"""Integration tests for multi-user session isolation in remote-hosted mode.
These tests compose PluginRegistry + PluginHub to verify that users
cannot see or interact with each other's Unity instances.
"""
import asyncio
from unittest.mock import AsyncMock
import pytest
from core.config import config
from transport.plugin_hub import NoUnitySessionError, PluginHub
from transport.plugin_registry import PluginRegistry
@pytest.fixture(autouse=True)
def _reset_plugin_hub():
old_registry = PluginHub._registry
old_connections = PluginHub._connections.copy()
old_pending = PluginHub._pending.copy()
old_lock = PluginHub._lock
old_loop = PluginHub._loop
yield
PluginHub._registry = old_registry
PluginHub._connections = old_connections
PluginHub._pending = old_pending
PluginHub._lock = old_lock
PluginHub._loop = old_loop
async def _setup_two_user_registry():
"""Set up a registry with two users, each having Unity instances.
Returns the configured registry. Also configures PluginHub to use it.
"""
registry = PluginRegistry()
loop = asyncio.get_running_loop()
PluginHub.configure(registry, loop)
await registry.register("sess-A1", "ProjectAlpha", "hashA1", "2022.3", user_id="userA")
await registry.register("sess-B1", "ProjectBeta", "hashB1", "2022.3", user_id="userB")
await registry.register("sess-A2", "ProjectGamma", "hashA2", "2022.3", user_id="userA")
return registry
class TestMultiUserSessionFiltering:
@pytest.mark.asyncio
async def test_get_sessions_filters_by_user(self):
"""PluginHub.get_sessions(user_id=X) returns only X's sessions."""
await _setup_two_user_registry()
sessions_a = await PluginHub.get_sessions(user_id="userA")
assert len(sessions_a.sessions) == 2
project_names = {s.project for s in sessions_a.sessions.values()}
assert project_names == {"ProjectAlpha", "ProjectGamma"}
sessions_b = await PluginHub.get_sessions(user_id="userB")
assert len(sessions_b.sessions) == 1
assert next(iter(sessions_b.sessions.values())
).project == "ProjectBeta"
@pytest.mark.asyncio
async def test_get_sessions_no_filter_returns_all_in_local_mode(self):
"""In local mode, PluginHub.get_sessions() without user_id returns everything."""
await _setup_two_user_registry()
all_sessions = await PluginHub.get_sessions()
assert len(all_sessions.sessions) == 3
@pytest.mark.asyncio
async def test_get_sessions_no_filter_raises_in_remote_hosted(self, monkeypatch):
"""In remote-hosted mode, PluginHub.get_sessions() without user_id raises."""
monkeypatch.setattr(config, "http_remote_hosted", True)
await _setup_two_user_registry()
with pytest.raises(ValueError, match="requires user_id"):
await PluginHub.get_sessions()
class TestResolveSessionIdIsolation:
@pytest.mark.asyncio
async def test_resolve_session_for_own_hash(self, monkeypatch):
"""User A can resolve their own project hash."""
monkeypatch.setattr(config, "http_remote_hosted", True)
await _setup_two_user_registry()
session_id = await PluginHub._resolve_session_id("hashA1", user_id="userA")
assert session_id == "sess-A1"
@pytest.mark.asyncio
async def test_cannot_resolve_other_users_hash(self, monkeypatch):
"""User A cannot resolve User B's project hash."""
monkeypatch.setattr(config, "http_remote_hosted", True)
monkeypatch.setenv("UNITY_MCP_SESSION_RESOLVE_MAX_WAIT_S", "0.1")
await _setup_two_user_registry()
# userA tries to resolve userB's hash -> should not find it
with pytest.raises(NoUnitySessionError):
await PluginHub._resolve_session_id("hashB1", user_id="userA")
class TestInstanceListResourceIsolation:
@pytest.mark.asyncio
async def test_unity_instances_resource_filters_by_user(self, monkeypatch):
"""The unity_instances resource should pass user_id and return filtered results."""
monkeypatch.setattr(config, "http_remote_hosted", True)
monkeypatch.setattr(config, "transport_mode", "http")
await _setup_two_user_registry()
from services.resources.unity_instances import unity_instances
from tests.integration.test_helpers import DummyContext
ctx = DummyContext()
ctx.set_state("user_id", "userA")
result = await unity_instances(ctx)
assert result["success"] is True
assert result["instance_count"] == 2
instance_names = {i["name"] for i in result["instances"]}
assert instance_names == {"ProjectAlpha", "ProjectGamma"}
assert "ProjectBeta" not in instance_names
class TestSetActiveInstanceIsolation:
@pytest.mark.asyncio
async def test_set_active_instance_only_sees_own_sessions(self, monkeypatch):
"""set_active_instance should only offer sessions belonging to the current user."""
monkeypatch.setattr(config, "http_remote_hosted", True)
monkeypatch.setattr(config, "transport_mode", "http")
await _setup_two_user_registry()
from services.tools.set_active_instance import set_active_instance
from transport.unity_instance_middleware import UnityInstanceMiddleware
from tests.integration.test_helpers import DummyContext
middleware = UnityInstanceMiddleware()
monkeypatch.setattr(
"services.tools.set_active_instance.get_unity_instance_middleware",
lambda: middleware,
)
ctx = DummyContext()
ctx.set_state("user_id", "userA")
result = await set_active_instance(ctx, "ProjectAlpha@hashA1")
assert result["success"] is True
assert middleware.get_active_instance(ctx) == "ProjectAlpha@hashA1"
@pytest.mark.asyncio
async def test_set_active_instance_rejects_other_users_instance(self, monkeypatch):
"""set_active_instance should not find another user's instance."""
monkeypatch.setattr(config, "http_remote_hosted", True)
monkeypatch.setattr(config, "transport_mode", "http")
await _setup_two_user_registry()
from services.tools.set_active_instance import set_active_instance
from transport.unity_instance_middleware import UnityInstanceMiddleware
from tests.integration.test_helpers import DummyContext
middleware = UnityInstanceMiddleware()
monkeypatch.setattr(
"services.tools.set_active_instance.get_unity_instance_middleware",
lambda: middleware,
)
ctx = DummyContext()
ctx.set_state("user_id", "userA")
# UserA tries to select UserB's instance -> should fail
result = await set_active_instance(ctx, "ProjectBeta@hashB1")
assert result["success"] is False

View File

@ -0,0 +1,183 @@
"""Tests for PluginHub WebSocket API key authentication gate."""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from core.config import config
from core.constants import API_KEY_HEADER
from services.api_key_service import ApiKeyService, ValidationResult
from transport.plugin_hub import PluginHub
from transport.plugin_registry import PluginRegistry
@pytest.fixture(autouse=True)
def _reset_api_key_singleton():
ApiKeyService._instance = None
yield
ApiKeyService._instance = None
@pytest.fixture(autouse=True)
def _reset_plugin_hub():
"""Ensure PluginHub class-level state doesn't leak between tests."""
old_registry = PluginHub._registry
old_connections = PluginHub._connections.copy()
old_pending = PluginHub._pending.copy()
old_lock = PluginHub._lock
old_loop = PluginHub._loop
yield
PluginHub._registry = old_registry
PluginHub._connections = old_connections
PluginHub._pending = old_pending
PluginHub._lock = old_lock
PluginHub._loop = old_loop
def _make_mock_websocket(headers=None, state_attrs=None):
"""Create a mock WebSocket with configurable headers and state."""
ws = AsyncMock()
ws.headers = headers or {}
ws.state = SimpleNamespace(**(state_attrs or {}))
ws.accept = AsyncMock()
ws.close = AsyncMock()
ws.send_json = AsyncMock()
return ws
def _make_hub():
"""Create a PluginHub instance with a minimal ASGI scope."""
scope = {"type": "websocket"}
return PluginHub(scope, receive=AsyncMock(), send=AsyncMock())
def _init_api_key_service(validate_result=None):
"""Initialize ApiKeyService with a mocked validate method."""
svc = ApiKeyService(validation_url="https://auth.example.com/validate")
if validate_result is not None:
svc.validate = AsyncMock(return_value=validate_result)
return svc
class TestWebSocketAuthGate:
@pytest.mark.asyncio
async def test_no_api_key_remote_hosted_rejected(self, monkeypatch):
"""WebSocket without API key in remote-hosted mode -> close 4401."""
monkeypatch.setattr(config, "http_remote_hosted", True)
_init_api_key_service(ValidationResult(valid=True, user_id="u1"))
ws = _make_mock_websocket(headers={}) # No X-API-Key header
hub = _make_hub()
await hub.on_connect(ws)
ws.close.assert_called_once_with(code=4401, reason="API key required")
ws.accept.assert_not_called()
@pytest.mark.asyncio
async def test_invalid_api_key_rejected(self, monkeypatch):
"""WebSocket with invalid API key -> close 4403."""
monkeypatch.setattr(config, "http_remote_hosted", True)
_init_api_key_service(ValidationResult(
valid=False, error="Invalid API key"))
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-bad-key"})
hub = _make_hub()
await hub.on_connect(ws)
ws.close.assert_called_once_with(code=4403, reason="Invalid API key")
ws.accept.assert_not_called()
@pytest.mark.asyncio
async def test_valid_api_key_accepted(self, monkeypatch):
"""WebSocket with valid API key -> accepted, user_id stored in state."""
monkeypatch.setattr(config, "http_remote_hosted", True)
_init_api_key_service(
ValidationResult(valid=True, user_id="user-42",
metadata={"plan": "pro"})
)
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-valid-key"})
hub = _make_hub()
await hub.on_connect(ws)
ws.accept.assert_called_once()
ws.close.assert_not_called()
assert ws.state.user_id == "user-42"
assert ws.state.api_key_metadata == {"plan": "pro"}
# Should have sent welcome message
ws.send_json.assert_called_once()
@pytest.mark.asyncio
async def test_auth_service_unavailable_close_1013(self, monkeypatch):
"""Auth service error with 'unavailable' -> close 1013 (try again later)."""
monkeypatch.setattr(config, "http_remote_hosted", True)
_init_api_key_service(
ValidationResult(
valid=False, error="Auth service unavailable", cacheable=False)
)
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-some-key"})
hub = _make_hub()
await hub.on_connect(ws)
ws.close.assert_called_once_with(code=1013, reason="Try again later")
ws.accept.assert_not_called()
@pytest.mark.asyncio
async def test_not_remote_hosted_accepts_without_key(self, monkeypatch):
"""When not remote-hosted, WebSocket accepted without API key."""
monkeypatch.setattr(config, "http_remote_hosted", False)
ws = _make_mock_websocket(headers={})
hub = _make_hub()
await hub.on_connect(ws)
ws.accept.assert_called_once()
ws.close.assert_not_called()
class TestUserIdFlowsToRegistration:
@pytest.mark.asyncio
async def test_user_id_passed_to_registry_on_register(self, monkeypatch):
"""After valid auth, the register message should pass user_id to registry."""
monkeypatch.setattr(config, "http_remote_hosted", True)
_init_api_key_service(
ValidationResult(valid=True, user_id="user-99")
)
registry = PluginRegistry()
loop = asyncio.get_running_loop()
PluginHub.configure(registry, loop)
# Simulate full flow: connect, then register
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-valid-key"})
hub = _make_hub()
await hub.on_connect(ws)
assert ws.state.user_id == "user-99"
# Simulate register message
register_data = {
"type": "register",
"project_name": "TestProject",
"project_hash": "abc123",
"unity_version": "2022.3",
}
await hub.on_receive(ws, register_data)
# Verify registry has the user_id
sessions = await registry.list_sessions(user_id="user-99")
assert len(sessions) == 1
session = next(iter(sessions.values()))
assert session.user_id == "user-99"
assert session.project_name == "TestProject"
assert session.project_hash == "abc123"

View File

@ -0,0 +1,112 @@
"""Tests for PluginRegistry user-scoped session isolation (remote-hosted mode)."""
import pytest
from core.config import config
from transport.plugin_registry import PluginRegistry
class TestRegistryUserIsolation:
@pytest.mark.asyncio
async def test_register_with_user_id_stores_composite_key(self):
registry = PluginRegistry()
session = await registry.register(
"sess-1", "MyProject", "hash1", "2022.3", user_id="user-A"
)
assert session.user_id == "user-A"
assert ("user-A", "hash1") in registry._user_hash_to_session
assert registry._user_hash_to_session[("user-A", "hash1")] == "sess-1"
@pytest.mark.asyncio
async def test_get_session_id_by_hash(self):
registry = PluginRegistry()
await registry.register("sess-1", "Proj", "h1", "2022", user_id="uA")
found = await registry.get_session_id_by_hash("h1", "uA")
assert found == "sess-1"
# Different user, same hash -> not found
not_found = await registry.get_session_id_by_hash("h1", "uB")
assert not_found is None
@pytest.mark.asyncio
async def test_cross_user_isolation_same_hash(self):
"""Two users registering with the same project_hash get independent sessions."""
registry = PluginRegistry()
sess_a = await registry.register("sA", "Proj", "hash1", "2022", user_id="userA")
sess_b = await registry.register("sB", "Proj", "hash1", "2022", user_id="userB")
assert sess_a.session_id == "sA"
assert sess_b.session_id == "sB"
# Each user resolves to their own session
assert await registry.get_session_id_by_hash("hash1", "userA") == "sA"
assert await registry.get_session_id_by_hash("hash1", "userB") == "sB"
# Both sessions exist
all_sessions = await registry.list_sessions()
assert len(all_sessions) == 2
@pytest.mark.asyncio
async def test_list_sessions_filtered_by_user(self):
registry = PluginRegistry()
await registry.register("s1", "ProjA", "hA", "2022", user_id="userA")
await registry.register("s2", "ProjB", "hB", "2022", user_id="userB")
await registry.register("s3", "ProjC", "hC", "2022", user_id="userA")
user_a_sessions = await registry.list_sessions(user_id="userA")
assert len(user_a_sessions) == 2
assert "s1" in user_a_sessions
assert "s3" in user_a_sessions
user_b_sessions = await registry.list_sessions(user_id="userB")
assert len(user_b_sessions) == 1
assert "s2" in user_b_sessions
@pytest.mark.asyncio
async def test_list_sessions_no_filter_returns_all_in_local_mode(self):
"""In local mode (not remote-hosted), list_sessions(user_id=None) returns all."""
registry = PluginRegistry()
await registry.register("s1", "P1", "h1", "2022", user_id="uA")
await registry.register("s2", "P2", "h2", "2022", user_id="uB")
await registry.register("s3", "P3", "h3", "2022") # local mode, no user_id
all_sessions = await registry.list_sessions(user_id=None)
assert len(all_sessions) == 3
@pytest.mark.asyncio
async def test_list_sessions_no_filter_raises_in_remote_hosted(self, monkeypatch):
"""In remote-hosted mode, list_sessions(user_id=None) raises ValueError."""
monkeypatch.setattr(config, "http_remote_hosted", True)
registry = PluginRegistry()
await registry.register("s1", "P1", "h1", "2022", user_id="uA")
with pytest.raises(ValueError, match="requires user_id"):
await registry.list_sessions(user_id=None)
@pytest.mark.asyncio
async def test_unregister_cleans_user_scoped_mapping(self):
registry = PluginRegistry()
await registry.register("s1", "Proj", "h1", "2022", user_id="uA")
assert ("uA", "h1") in registry._user_hash_to_session
await registry.unregister("s1")
assert ("uA", "h1") not in registry._user_hash_to_session
assert "s1" not in (await registry.list_sessions())
@pytest.mark.asyncio
async def test_reconnect_replaces_previous_session(self):
"""Same (user_id, hash) re-registered evicts old session, stores new one."""
registry = PluginRegistry()
await registry.register("old-sess", "Proj", "h1", "2022", user_id="uA")
assert await registry.get_session_id_by_hash("h1", "uA") == "old-sess"
await registry.register("new-sess", "Proj", "h1", "2022", user_id="uA")
assert await registry.get_session_id_by_hash("h1", "uA") == "new-sess"
# Old session should be evicted
all_sessions = await registry.list_sessions()
assert "old-sess" not in all_sessions
assert "new-sess" in all_sessions

View File

@ -0,0 +1,114 @@
"""Tests for _resolve_user_id_from_request in unity_transport.py."""
import sys
import types
from unittest.mock import AsyncMock
import pytest
from core.config import config
from services.api_key_service import ApiKeyService, ValidationResult
@pytest.fixture(autouse=True)
def _reset_api_key_singleton():
ApiKeyService._instance = None
yield
ApiKeyService._instance = None
class TestResolveUserIdFromRequest:
@pytest.mark.asyncio
async def test_returns_none_when_not_remote_hosted(self, monkeypatch):
monkeypatch.setattr(config, "http_remote_hosted", False)
from transport.unity_transport import _resolve_user_id_from_request
result = await _resolve_user_id_from_request()
assert result is None
@pytest.mark.asyncio
async def test_returns_none_when_service_not_initialized(self, monkeypatch):
monkeypatch.setattr(config, "http_remote_hosted", True)
# ApiKeyService._instance is None (from fixture)
from transport.unity_transport import _resolve_user_id_from_request
result = await _resolve_user_id_from_request()
assert result is None
@pytest.mark.asyncio
async def test_returns_user_id_for_valid_key(self, monkeypatch):
monkeypatch.setattr(config, "http_remote_hosted", True)
svc = ApiKeyService(validation_url="https://auth.example.com/validate")
svc.validate = AsyncMock(
return_value=ValidationResult(valid=True, user_id="user-123")
)
# Stub the fastmcp dependency that provides HTTP headers
deps_mod = types.ModuleType("fastmcp.server.dependencies")
deps_mod.get_http_headers = lambda include_all=False: {
"x-api-key": "sk-valid"}
monkeypatch.setitem(
sys.modules, "fastmcp.server.dependencies", deps_mod)
from transport.unity_transport import _resolve_user_id_from_request
result = await _resolve_user_id_from_request()
assert result == "user-123"
svc.validate.assert_called_once_with("sk-valid")
@pytest.mark.asyncio
async def test_returns_none_for_invalid_key(self, monkeypatch):
monkeypatch.setattr(config, "http_remote_hosted", True)
svc = ApiKeyService(validation_url="https://auth.example.com/validate")
svc.validate = AsyncMock(
return_value=ValidationResult(valid=False, error="bad key")
)
deps_mod = types.ModuleType("fastmcp.server.dependencies")
deps_mod.get_http_headers = lambda include_all=False: {
"x-api-key": "sk-bad"}
monkeypatch.setitem(
sys.modules, "fastmcp.server.dependencies", deps_mod)
from transport.unity_transport import _resolve_user_id_from_request
result = await _resolve_user_id_from_request()
assert result is None
@pytest.mark.asyncio
async def test_returns_none_on_exception(self, monkeypatch):
monkeypatch.setattr(config, "http_remote_hosted", True)
svc = ApiKeyService(validation_url="https://auth.example.com/validate")
svc.validate = AsyncMock(side_effect=RuntimeError("boom"))
deps_mod = types.ModuleType("fastmcp.server.dependencies")
deps_mod.get_http_headers = lambda include_all=False: {
"x-api-key": "sk-err"}
monkeypatch.setitem(
sys.modules, "fastmcp.server.dependencies", deps_mod)
from transport.unity_transport import _resolve_user_id_from_request
result = await _resolve_user_id_from_request()
assert result is None
@pytest.mark.asyncio
async def test_returns_none_when_no_api_key_header(self, monkeypatch):
monkeypatch.setattr(config, "http_remote_hosted", True)
ApiKeyService(validation_url="https://auth.example.com/validate")
deps_mod = types.ModuleType("fastmcp.server.dependencies")
deps_mod.get_http_headers = lambda include_all=False: {} # No x-api-key
monkeypatch.setitem(
sys.modules, "fastmcp.server.dependencies", deps_mod)
from transport.unity_transport import _resolve_user_id_from_request
result = await _resolve_user_id_from_request()
assert result is None

View File

@ -1,3 +1,4 @@
import logging
import types import types
import threading import threading
import time import time
@ -7,7 +8,13 @@ import core.telemetry as telemetry
def test_telemetry_queue_backpressure_and_single_worker(monkeypatch, caplog): def test_telemetry_queue_backpressure_and_single_worker(monkeypatch, caplog):
caplog.set_level("DEBUG") # Directly attach caplog's handler to the telemetry logger so that
# earlier tests calling logging.basicConfig() can't steal the records
# via a root handler before caplog sees them.
tel_logger = logging.getLogger("unity-mcp-telemetry")
tel_logger.addHandler(caplog.handler)
try:
caplog.set_level("DEBUG", logger="unity-mcp-telemetry")
collector = telemetry.TelemetryCollector() collector = telemetry.TelemetryCollector()
# Force-enable telemetry regardless of env settings from conftest # Force-enable telemetry regardless of env settings from conftest
@ -18,8 +25,9 @@ def test_telemetry_queue_backpressure_and_single_worker(monkeypatch, caplog):
# Replace queue with tiny one to trigger backpressure quickly # Replace queue with tiny one to trigger backpressure quickly
small_q = q.Queue(maxsize=2) small_q = q.Queue(maxsize=2)
collector._queue = small_q collector._queue = small_q
# Give the worker a moment to switch queues # Give the worker time to finish processing the seeded item and
time.sleep(0.02) # re-enter _queue.get() on the new small queue
time.sleep(0.2)
# Make sends slow to build backlog and exercise worker # Make sends slow to build backlog and exercise worker
def slow_send(self, rec): def slow_send(self, rec):
@ -52,3 +60,6 @@ def test_telemetry_queue_backpressure_and_single_worker(monkeypatch, caplog):
worker_threads = [ worker_threads = [
t for t in threading.enumerate() if t is collector._worker] t for t in threading.enumerate() if t is collector._worker]
assert len(worker_threads) == 1 assert len(worker_threads) == 1
finally:
if caplog.handler in tel_logger.handlers:
tel_logger.removeHandler(caplog.handler)

View File

@ -0,0 +1,80 @@
"""End-to-end-ish smoke tests for transport routing paths."""
from __future__ import annotations
import pytest
from core.config import config
from transport import unity_transport
@pytest.mark.asyncio
async def test_http_local_smoke(monkeypatch):
"""HTTP local should route through PluginHub without requiring user_id."""
monkeypatch.setattr(config, "transport_mode", "http")
monkeypatch.setattr(config, "http_remote_hosted", False)
async def fake_send_command_for_instance(_instance, _command, _params, **_kwargs):
return {"status": "success", "result": {"message": "ok", "data": {"via": "http"}}}
monkeypatch.setattr(
unity_transport.PluginHub,
"send_command_for_instance",
fake_send_command_for_instance,
)
async def _unused_send_fn(*_args, **_kwargs):
raise AssertionError("send_fn should not be used in HTTP mode")
result = await unity_transport.send_with_unity_instance(
_unused_send_fn, None, "ping", {}
)
assert result["success"] is True
assert result["data"] == {"via": "http"}
@pytest.mark.asyncio
async def test_http_remote_smoke(monkeypatch):
"""HTTP remote-hosted should route through PluginHub when user_id is provided."""
monkeypatch.setattr(config, "transport_mode", "http")
monkeypatch.setattr(config, "http_remote_hosted", True)
async def fake_send_command_for_instance(_instance, _command, _params, **_kwargs):
return {"status": "success", "result": {"data": {"via": "http-remote"}}}
monkeypatch.setattr(
unity_transport.PluginHub,
"send_command_for_instance",
fake_send_command_for_instance,
)
async def _unused_send_fn(*_args, **_kwargs):
raise AssertionError("send_fn should not be used in HTTP mode")
result = await unity_transport.send_with_unity_instance(
_unused_send_fn, None, "ping", {}, user_id="user-1"
)
assert result["success"] is True
assert result["data"] == {"via": "http-remote"}
@pytest.mark.asyncio
async def test_stdio_smoke(monkeypatch):
"""Stdio transport should call the legacy send fn with instance_id."""
monkeypatch.setattr(config, "transport_mode", "stdio")
monkeypatch.setattr(config, "http_remote_hosted", False)
async def fake_send_fn(command_type, params, *, instance_id=None, **_kwargs):
return {
"success": True,
"data": {"via": "stdio", "command": command_type, "instance": instance_id, "params": params},
}
result = await unity_transport.send_with_unity_instance(
fake_send_fn, "Project@abcd1234", "ping", {"x": 1}
)
assert result["success"] is True
assert result["data"]["via"] == "stdio"
assert result["data"]["instance"] == "Project@abcd1234"

View File

@ -65,21 +65,6 @@ def mock_instances_response():
} }
@pytest.fixture
def mock_sessions_response():
"""Mock plugin sessions response (legacy format)."""
return {
"sessions": {
"test-session-123": {
"project": "TestProject",
"hash": "abc123def456",
"unity_version": "2022.3.10f1",
"connected_at": "2024-01-01T00:00:00Z",
}
}
}
# ============================================================================= # =============================================================================
# Config Tests # Config Tests
# ============================================================================= # =============================================================================
@ -246,23 +231,6 @@ class TestConnection:
with pytest.raises(UnityConnectionError): with pytest.raises(UnityConnectionError):
await send_command("test_command", {}) await send_command("test_command", {})
@pytest.mark.asyncio
async def test_list_instances_from_sessions(self, mock_sessions_response):
"""Test listing instances from /plugin/sessions endpoint."""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = mock_sessions_response
with patch("httpx.AsyncClient") as mock_client:
# First call (api/instances) returns 404, second (plugin/sessions) succeeds
mock_get = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.get = mock_get
result = await list_unity_instances()
assert result["success"] is True
assert len(result["instances"]) == 1
assert result["instances"][0]["project"] == "TestProject"
# ============================================================================= # =============================================================================
# CLI Command Tests # CLI Command Tests

View File

@ -21,7 +21,7 @@ import uuid
from transport.unity_instance_middleware import UnityInstanceMiddleware, get_unity_instance_middleware, set_unity_instance_middleware from transport.unity_instance_middleware import UnityInstanceMiddleware, get_unity_instance_middleware, set_unity_instance_middleware
from transport.plugin_registry import PluginRegistry, PluginSession from transport.plugin_registry import PluginRegistry, PluginSession
from transport.plugin_hub import PluginHub, NoUnitySessionError, PluginDisconnectedError from transport.plugin_hub import PluginHub, NoUnitySessionError, InstanceSelectionRequiredError, PluginDisconnectedError
from transport.models import ( from transport.models import (
RegisterMessage, RegisterMessage,
RegisterToolsMessage, RegisterToolsMessage,
@ -775,7 +775,7 @@ class TestSessionResolution:
unity_version="2023.2" unity_version="2023.2"
) )
with pytest.raises(RuntimeError, match="Multiple Unity instances"): with pytest.raises(InstanceSelectionRequiredError, match="Multiple Unity instances"):
await PluginHub._resolve_session_id(None) await PluginHub._resolve_session_id(None)
# Cleanup # Cleanup

View File

@ -0,0 +1,261 @@
# Remote Server API Key Authentication
When running the MCP for Unity server as a shared remote service, API key authentication ensures that only authorized users can access the server and that each user's Unity sessions are isolated from one another.
This guide covers how to configure, deploy, and use the feature.
## Prerequisites
### External Auth Service
You need an external HTTP endpoint that validates API keys. The server delegates all key validation to this endpoint rather than managing keys itself.
The endpoint must:
- Accept `POST` requests with a JSON body: `{"api_key": "<key>"}`
- Return a JSON response indicating validity and the associated user identity
- Be reachable from the MCP server over the network
See [Validation Contract](#validation-contract) for the full request/response specification.
### Transport Mode
API key authentication is only available when running with HTTP transport (`--transport http`). It has no effect in stdio mode.
## Server Configuration
### CLI Arguments
| Argument | Environment Variable | Default | Description |
| -------- | -------------------- | ------- | ----------- |
| `--http-remote-hosted` | `UNITY_MCP_HTTP_REMOTE_HOSTED` | `false` | Enable remote-hosted mode. Requires API key auth. |
| `--api-key-validation-url URL` | `UNITY_MCP_API_KEY_VALIDATION_URL` | None | External endpoint to validate API keys (required). |
| `--api-key-login-url URL` | `UNITY_MCP_API_KEY_LOGIN_URL` | None | URL where users can obtain or manage API keys. |
| `--api-key-cache-ttl SECONDS` | `UNITY_MCP_API_KEY_CACHE_TTL` | `300` | How long validated keys are cached (seconds). |
| `--api-key-service-token-header HEADER` | `UNITY_MCP_API_KEY_SERVICE_TOKEN_HEADER` | None | Header name for server-to-auth-service authentication. |
| `--api-key-service-token TOKEN` | `UNITY_MCP_API_KEY_SERVICE_TOKEN` | None | Token value sent to the auth service for server authentication. |
Environment variables take effect when the corresponding CLI argument is not provided. For boolean flags, set the env var to `true`, `1`, or `yes`.
### Startup Validation
The server validates its configuration at startup:
- If `--http-remote-hosted` is set but `--api-key-validation-url` is not provided (and the env var is also unset), the server logs an error and exits with code 1.
### Example
```bash
python -m src.main \
--transport http \
--http-host 0.0.0.0 \
--http-port 8080 \
--http-remote-hosted \
--api-key-validation-url https://auth.example.com/api/validate-key \
--api-key-login-url https://app.example.com/api-keys \
--api-key-cache-ttl 120
```
Or using environment variables:
```bash
export UNITY_MCP_TRANSPORT=http
export UNITY_MCP_HTTP_HOST=0.0.0.0
export UNITY_MCP_HTTP_PORT=8080
export UNITY_MCP_HTTP_REMOTE_HOSTED=true
export UNITY_MCP_API_KEY_VALIDATION_URL=https://auth.example.com/api/validate-key
export UNITY_MCP_API_KEY_LOGIN_URL=https://app.example.com/api-keys
python -m src.main
```
### Service Token (Optional)
If your auth service requires the MCP server to authenticate itself (server-to-server auth), configure a service token:
```bash
--api-key-service-token-header X-Service-Token \
--api-key-service-token "your-server-secret"
```
This adds the specified header to every validation request sent to the auth endpoint.
We strongly recommend using this feature because it ensures that the entity requesting validation is the MCP server itself, not an imposter.
## Unity Plugin Setup
When connecting to a remote-hosted server, Unity users need to provide their API key:
1. Open the MCP for Unity window in the Unity Editor.
2. Select HTTP Remote as the connection mode.
3. Enter the API key in the API Key field. The key is stored in `EditorPrefs` (per-machine, not source-controlled).
4. Click **Get API Key** to open the login URL in a browser if you need a new key. This fetches the URL from the server's `/api/auth/login-url` endpoint.
The API key is a one-time entry per machine. It persists across Unity sessions until explicitly cleared.
## MCP Client Configuration
When an API key is configured, the Unity plugin's MCP client configurators automatically include the `X-API-Key` header in generated configuration files.
Example generated config for **Cursor** (`~/.cursor/mcp.json`):
```json
{
"mcpServers": {
"mcp-for-unity": {
"url": "http://remote-server:8080/mcp",
"headers": {
"X-API-Key": "<your-api-key>"
}
}
}
}
```
Example for **Claude Code** (CLI):
```bash
claude mcp add --transport http mcp-for-unity http://remote-server:8080/mcp \
--header "X-API-Key: <your-api-key>"
```
Similar header injection works for VS Code, Windsurf, Cline, and other supported MCP clients.
## Behaviour Changes in Remote-Hosted Mode
Enabling `--http-remote-hosted` changes several server behaviours compared to the default local mode:
### Authentication Enforcement
All MCP tool and resource calls require a valid API key. The `X-API-Key` header must be present on every HTTP request to the `/mcp` endpoint. If the key is missing or invalid, the middleware raises a `RuntimeError` that surfaces as an MCP error response.
### WebSocket Auth Gate
Unity plugins connecting via WebSocket (`/hub/plugin`) are validated during the handshake:
| Scenario | WebSocket Close Code | Reason |
| -------- | -------------------- | ------ |
| No API key header | `4401` | API key required |
| Invalid API key | `4403` | Invalid API key |
| Auth service unavailable | `1013` | Try again later |
| Valid API key | Connection accepted | user_id stored in connection state |
### Session Isolation
Each user can only see and interact with their own Unity instances. When User A calls `set_active_instance` or lists instances, they only see Unity editors that connected with User A's API key. User B's sessions are invisible to User A.
### Auto-Select Disabled
In local mode, the server automatically selects the sole connected Unity instance. In remote-hosted mode, this auto-selection is disabled. Users must explicitly call `set_active_instance` with a `Name@hash` from the `mcpforunity://instances` resource.
### CLI Routes Disabled
The following REST endpoints are disabled in remote-hosted mode to prevent unauthenticated access:
- `POST /api/command`
- `GET /api/instances`
- `GET /api/custom-tools`
### Endpoints Always Available
These endpoints remain accessible regardless of auth:
| Endpoint | Method | Purpose |
| -------- | ------ | ------- |
| `/health` | GET | Health check for load balancers and monitoring |
| `/api/auth/login-url` | GET | Returns the login URL for API key management |
## Validation Contract
### Request
```http
POST <api-key-validation-url>
Content-Type: application/json
{
"api_key": "<the-api-key>"
}
```
If a service token is configured, an additional header is sent:
```http
<service-token-header>: <service-token-value>
```
### Response (Valid Key)
```json
{
"valid": true,
"user_id": "user-abc-123",
"metadata": {}
}
```
- `valid` (bool, required): Must be `true`.
- `user_id` (string, required): Stable identifier for the user. Used for session isolation.
- `metadata` (object, optional): Arbitrary metadata stored alongside the validation result.
### Response (Invalid Key)
```json
{
"valid": false,
"error": "API key expired"
}
```
- `valid` (bool, required): Must be `false`.
- `error` (string, optional): Human-readable reason.
### Response (HTTP 401)
A `401` status code is also treated as an invalid key (no body parsing required).
### Timeouts and Retries
- Request timeout: 5 seconds
- Retries: 1 (with 100ms backoff)
- Failure mode: deny by default (treated as invalid on any error)
Transient failures (5xx, timeouts, network errors) are **not cached**, so subsequent requests will retry the auth service.
## Error Reference
| Context | Condition | Response |
| ------- | --------- | -------- |
| MCP tool/resource | Missing API key (remote-hosted) | `RuntimeError` → MCP `isError: true` |
| MCP tool/resource | Invalid API key | `RuntimeError` → MCP `isError: true` |
| WebSocket connect | Missing API key | Close `4401` "API key required" |
| WebSocket connect | Invalid API key | Close `4403` "Invalid API key" |
| WebSocket connect | Auth service down | Close `1013` "Try again later" |
| `/api/auth/login-url` | Login URL not configured | HTTP `404` with admin guidance message |
| Server startup | Remote-hosted without validation URL | `SystemExit(1)` |
## Troubleshooting
### "API key authentication required" error on every tool call
The server is in remote-hosted mode but no API key is being sent. Ensure the MCP client configuration includes the `X-API-Key` header, or set it in the Unity plugin's connection settings.
### Server exits immediately with code 1
The `--http-remote-hosted` flag requires `--api-key-validation-url`. Provide the URL via CLI argument or `UNITY_MCP_API_KEY_VALIDATION_URL` environment variable.
### WebSocket connection closes with 4401
The Unity plugin is not sending an API key. Enter the key in the MCP for Unity window's connection settings.
### WebSocket connection closes with 1013
The external auth service is unreachable. Check network connectivity between the MCP server and the validation URL. The Unity plugin can retry the connection.
### User cannot see their Unity instance
Session isolation is active. The Unity editor and the MCP client must use API keys that resolve to the same `user_id`. Verify that the Unity plugin's WebSocket connection and the MCP client's HTTP requests use the same API key.
### Stale auth after key rotation
Validated keys are cached for `--api-key-cache-ttl` seconds (default: 300). After rotating or revoking a key, there is a delay equal to the TTL before the old key stops working. Lower the TTL for faster revocation at the cost of more frequent validation requests.

View File

@ -0,0 +1,363 @@
# Remote Server Auth: Architecture
This document describes the internal design of the API key authentication system used when the MCP for Unity server runs in remote-hosted mode. It is intended for contributors and maintainers.
## Overview
```
MCP Client MCP Server External Auth
(Cursor, etc.) (Python) Service
| | |
| X-API-Key: abc123 | |
| POST /mcp (tool call) | |
|-------------------------->| |
| | |
| UnityInstanceMiddleware.on_call_tool |
| | |
| _resolve_user_id() |
| | |
| | POST /validate |
| | {"api_key": "abc123"} |
| |------------------------------>|
| | |
| | {"valid":true, |
| | "user_id":"user-42"} |
| |<------------------------------|
| | |
| Cache result (TTL) |
| | |
| ctx.set_state("user_id", "user-42") |
| ctx.set_state("unity_instance", "Proj@hash") |
| | |
| PluginHub.send_command_for_instance |
| (user_id scoped session lookup) |
| | |
| Tool result | |
|<--------------------------| |
Unity Plugin MCP Server External Auth
(C# WebSocket) (Python) Service
| | |
| WS /hub/plugin | |
| X-API-Key: abc123 | |
|-------------------------->| |
| | |
| PluginHub.on_connect |
| | POST /validate |
| |------------------------------>|
| | {"valid":true, ...} |
| |<------------------------------|
| | |
| accept() | |
| websocket.state.user_id = "user-42" |
|<--------------------------| |
| | |
| {"type":"register", ...} | |
|-------------------------->| |
| | |
| PluginRegistry.register( |
| ..., user_id="user-42") |
| _user_hash_to_session[("user-42","hash")] = sid |
| | |
| {"type":"registered"} | |
|<--------------------------| |
```
## Components
### ApiKeyService
**File:** `Server/src/services/api_key_service.py`
Singleton service that validates API keys against an external HTTP endpoint.
- **Singleton access:** `ApiKeyService.get_instance()` / `ApiKeyService.is_initialized()`
- **Initialization:** Constructed in `create_mcp_server()` when `config.http_remote_hosted` and `config.api_key_validation_url` are both set.
- **Validation:** `async validate(api_key) -> ValidationResult`
- **Caching:** In-memory dict keyed by raw API key. Entries store `(valid, user_id, metadata, expires_at)`.
- **Retry:** 1 retry with 100ms backoff on timeouts and connection errors.
- **Fail-closed:** Any unrecoverable error returns `ValidationResult(valid=False)`.
### PluginHub (WebSocket Auth Gate)
**File:** `Server/src/transport/plugin_hub.py`
The `on_connect` method validates the API key from the WebSocket handshake headers before accepting the connection.
- Reads `X-API-Key` from `websocket.headers`
- Validates via `ApiKeyService.validate()`
- Stores `user_id` and `api_key_metadata` on `websocket.state` for use during registration
- Rejects with close codes: `4401` (missing), `4403` (invalid), `1013` (service unavailable)
The `_handle_register` method reads `websocket.state.user_id` and passes it to `PluginRegistry.register()`.
The `get_sessions(user_id=None)` and `_resolve_session_id(unity_instance, user_id=None)` methods accept an optional `user_id` to scope session queries in remote-hosted mode.
### PluginRegistry (Dual-Index Session Storage)
**File:** `Server/src/transport/plugin_registry.py`
In-memory registry of connected Unity plugin sessions. Maintains two parallel index maps:
| Index | Key | Used In |
|-------|-----|---------|
| `_hash_to_session` | `project_hash -> session_id` | Local mode |
| `_user_hash_to_session` | `(user_id, project_hash) -> session_id` | Remote-hosted mode |
Both indexes are updated during `register()` and cleaned up during `unregister()`.
Key methods:
- `register(session_id, project_name, project_hash, unity_version, user_id=None)` - Registers a session and updates the appropriate index. If an existing session claims the same key, it is evicted.
- `get_session_id_by_hash(project_hash)` - Local-mode lookup.
- `get_session_id_by_hash(project_hash, user_id)` - Remote-mode lookup.
- `list_sessions(user_id=None)` - Returns sessions filtered by user. Raises `ValueError` if `user_id` is `None` while `config.http_remote_hosted` is `True`, preventing accidental cross-user leaks.
### UnityInstanceMiddleware
**File:** `Server/src/transport/unity_instance_middleware.py`
FastMCP middleware that intercepts all tool and resource calls to inject the active Unity instance and user identity into the request-scoped context.
Entry points:
- `on_call_tool(context, call_next)` - Intercepts tool calls.
- `on_read_resource(context, call_next)` - Intercepts resource reads.
Both delegate to `_inject_unity_instance(context)`, which:
1. Calls `_resolve_user_id()` to extract the user identity from the HTTP request.
2. If remote-hosted mode is active and no `user_id` is resolved, raises `RuntimeError` (surfaces as MCP error).
3. Sets `ctx.set_state("user_id", user_id)`.
4. Looks up or auto-selects the active Unity instance.
5. Sets `ctx.set_state("unity_instance", active_instance)`.
### _resolve_user_id_from_request
**File:** `Server/src/transport/unity_transport.py`
Extracts the `user_id` from the current HTTP request's `X-API-Key` header.
```
_resolve_user_id_from_request()
-> if not config.http_remote_hosted: return None
-> if not ApiKeyService.is_initialized(): return None
-> get_http_headers() from FastMCP dependencies
-> extract "x-api-key" header
-> ApiKeyService.validate(api_key)
-> return result.user_id if valid, else None
```
The middleware calls this indirectly through `_resolve_user_id()`, which adds an early return when not in remote-hosted mode (avoiding the import of FastMCP internals in local mode).
## Request Lifecycle
A complete authenticated MCP tool call follows this path:
1. **HTTP request arrives** at `/mcp` with `X-API-Key: <key>` header.
2. **FastMCP dispatches** the MCP tool call through its middleware chain.
3. **`UnityInstanceMiddleware.on_call_tool`** is invoked.
4. **`_inject_unity_instance`** runs:
- Calls `_resolve_user_id()`, which calls `_resolve_user_id_from_request()`.
- The request function imports `get_http_headers` from FastMCP and reads the `x-api-key` header.
- `ApiKeyService.validate()` checks the cache or calls the external auth endpoint.
- If valid, `user_id` is returned. If invalid or missing, `None` is returned.
- In remote-hosted mode, `None` causes a `RuntimeError`.
5. **`user_id` stored in context** via `ctx.set_state("user_id", user_id)`.
6. **Session key derived** by `get_session_key(ctx)`:
- Priority: `client_id` (if available) > `user:{user_id}` > `"global"`.
- The `user:{user_id}` fallback ensures session isolation when MCP transports don't provide stable client IDs.
7. **Active Unity instance looked up** from `_active_by_key` dict using the session key. If none is set, `_maybe_autoselect_instance` is called (but returns `None` in remote-hosted mode).
8. **Instance injected** via `ctx.set_state("unity_instance", active_instance)`.
9. **Tool executes**, reading the instance from `ctx.get_state("unity_instance")`.
10. **Command routed** through `PluginHub.send_command_for_instance(unity_instance, ..., user_id=user_id)`, which resolves the session using `PluginRegistry.get_session_id_by_hash(project_hash, user_id)`.
## WebSocket Auth Flow
When a Unity plugin connects via WebSocket:
```
Plugin -> WS /hub/plugin (with X-API-Key header)
|
v
PluginHub.on_connect()
|
+-- config.http_remote_hosted && ApiKeyService.is_initialized()?
| |
| +-- No -> accept() (local mode, no auth needed)
| |
| +-- Yes -> read X-API-Key from headers
| |
| +-- No key -> close(4401, "API key required")
| |
| +-- ApiKeyService.validate(key)
| |
| +-- valid=True -> websocket.state.user_id = user_id
| | accept()
| |
| +-- valid=False, "unavailable" in error
| | -> close(1013, "Try again later")
| |
| +-- valid=False -> close(4403, "Invalid API key")
```
After acceptance, when the plugin sends a `register` message, `_handle_register` reads `websocket.state.user_id` and passes it to `PluginRegistry.register()`.
## Session Registry Design
### Local Mode
```
project_hash -> session_id
"abc123" -> "uuid-1"
"def456" -> "uuid-2"
```
A single `_hash_to_session` dict. Any user can see any session. `list_sessions(user_id=None)` returns all sessions.
### Remote-Hosted Mode
```
(user_id, project_hash) -> session_id
("user-A", "abc123") -> "uuid-1"
("user-B", "abc123") -> "uuid-3" (same project, different user)
("user-A", "def456") -> "uuid-2"
```
A separate `_user_hash_to_session` dict with composite keys. Two users working on cloned repos (same `project_hash`) get independent sessions.
### Reconnect Handling
When a Unity editor reconnects (e.g., after domain reload), `register()` detects the existing mapping for the same key and evicts the old session before inserting the new one. This ensures the latest WebSocket connection always wins.
### list_sessions Guard
`list_sessions(user_id=None)` raises `ValueError` when `config.http_remote_hosted` is `True`. This prevents code paths from accidentally listing all users' sessions. Every call site in remote-hosted mode must pass an explicit `user_id`.
## Caching Strategy
`ApiKeyService` maintains an in-memory cache:
```python
# api_key -> (valid, user_id, metadata, expires_at)
_cache: dict[str, tuple[bool, str | None, dict | None, float]]
```
### What Gets Cached
| Response | Cached? | Rationale |
|----------|---------|-----------|
| 200 + `valid: true` | Yes | Definitive valid result |
| 200 + `valid: false` | Yes | Definitive invalid result |
| 401 status | Yes | Definitive rejection |
| 5xx status | No | Transient; retry on next request |
| Timeout | No | Transient; retry on next request |
| Connection error | No | Transient; retry on next request |
| Unexpected exception | No | Transient; retry on next request |
Non-cacheable results use `ValidationResult(cacheable=False)`.
### Cache Lifecycle
- **TTL:** Configurable via `--api-key-cache-ttl` (default: 300 seconds).
- **Expiry:** Checked on read. Expired entries are deleted and re-validated.
- **Invalidation:** `invalidate_cache(api_key)` removes a single key. `clear_cache()` removes all.
- **Concurrency:** Protected by `asyncio.Lock`.
### Revocation Latency
A revoked key continues to work for up to `cache_ttl` seconds. Lower the TTL for faster revocation at the cost of more validation requests.
## Fail-Closed Behaviour
The system fails closed at every boundary:
| Component | Failure | Behaviour |
|-----------|---------|-----------|
| `ApiKeyService._validate_external` | Timeout after retries | `valid=False, cacheable=False` |
| `ApiKeyService._validate_external` | Connection error after retries | `valid=False, cacheable=False` |
| `ApiKeyService._validate_external` | 5xx status | `valid=False, cacheable=False` |
| `ApiKeyService._validate_external` | Unexpected exception | `valid=False, cacheable=False` |
| `PluginHub.on_connect` | Auth service unavailable | Close `1013` (retry hint) |
| `UnityInstanceMiddleware._inject_unity_instance` | No user_id in remote-hosted mode | `RuntimeError` |
API keys are never logged in full. Keys longer than 8 characters are redacted to `xxxx...yyyy` in log messages.
## Session Key Derivation
`UnityInstanceMiddleware.get_session_key(ctx)` determines which dict key to use for storing/retrieving the active Unity instance per session:
```
1. client_id (string, non-empty) -> return client_id
2. ctx.get_state("user_id") -> return "user:{user_id}"
3. fallback -> return "global"
```
- **`client_id`:** Stable per MCP client connection. Preferred when available.
- **`user:{user_id}`:** Used in remote-hosted mode when the MCP transport doesn't provide a stable client ID. Ensures different users don't share instance selections.
- **`"global"`:** Local-dev fallback for single-user scenarios. Unreachable in remote-hosted mode because the auth enforcement raises `RuntimeError` before this point if no `user_id` is available.
## Disabled Features in Remote-Hosted Mode
| Feature | Local Mode | Remote-Hosted Mode | Reason |
|---------|-----------|-------------------|--------|
| Auto-select sole instance | Enabled | Disabled | Implicit behaviour is dangerous with multiple users |
| CLI REST routes | Enabled | Disabled | No auth layer on these endpoints |
| `list_sessions(user_id=None)` | Returns all | Raises `ValueError` | Prevents accidental cross-user session leaks |
## Configuration Flow
```
CLI args / env vars
|
v
main.py: parser.parse_args()
|
+-- config.http_remote_hosted = args or env
+-- config.api_key_validation_url = args or env
+-- config.api_key_login_url = args or env
+-- config.api_key_cache_ttl = args or env (float)
+-- config.api_key_service_token_header = args or env
+-- config.api_key_service_token = args or env
|
+-- Validate: remote-hosted requires validation URL
| (exits with code 1 if missing)
|
v
create_mcp_server()
|
+-- get_unity_instance_middleware() -> registers middleware
|
+-- if remote-hosted + validation URL:
| ApiKeyService(
| validation_url, cache_ttl,
| service_token_header, service_token
| )
|
+-- WebSocketRoute("/hub/plugin", PluginHub)
|
+-- if not remote-hosted:
register CLI routes (/api/command, /api/instances, /api/custom-tools)
```
## Key Files
| File | Role |
|------|------|
| `Server/src/core/config.py` | `ServerConfig` dataclass with auth fields |
| `Server/src/main.py` | CLI argument parsing, startup validation, service initialization |
| `Server/src/services/api_key_service.py` | API key validation singleton with caching and retry |
| `Server/src/transport/plugin_hub.py` | WebSocket auth gate, user-scoped session queries |
| `Server/src/transport/plugin_registry.py` | Dual-index session storage (local + user-scoped) |
| `Server/src/transport/unity_instance_middleware.py` | Per-request user_id and instance injection |
| `Server/src/transport/unity_transport.py` | `_resolve_user_id_from_request` helper |