Remote server auth (#644)
* Disable the gloabl default to first session when hosting remotely
* Remove calls to /plugin/sessions
The newer /api/instances covers that data, and we want to remove these "expose all" endpoints
* Disable CLI routes when running in remote hosted mode
* Update server README
* feat: add API key authentication support for remote-hosted HTTP transport
- Add API key field to connection UI (visible only in HTTP Remote mode)
- Add "Get API Key" and "Clear" buttons with login URL retrieval
- Include X-API-Key header in WebSocket connections when configured
- Add API key to CLI commands (mcp add, claude mcp add) when set
- Update config.json generation to include headers with API key
- Add API key validation service with caching and configurable endpoints
- Add /api/auth/login-url endpoint
* feat: add environment variable support for HTTP remote hosted mode
- Add UNITY_MCP_HTTP_REMOTE_HOSTED environment variable as alternative to --http-remote-hosted flag
- Accept "true", "1", or "yes" values (case-insensitive)
- Update CLI help text to document environment variable option
* feat: add user isolation enforcement for remote-hosted mode session listing
- Raise ValueError when list_sessions() called without user_id in remote-hosted mode
- Add comprehensive integration tests for multi-user session isolation
- Add unit tests for PluginRegistry user-scoped session filtering
- Verify cross-user isolation with same project hash
- Test unity_instances resource and set_active_instance user filtering
* feat: add comprehensive integration tests for API key authentication
- Add ApiKeyService tests covering validation, caching, retries, and singleton lifecycle
- Add startup config validation tests for remote-hosted mode requirements
- Test cache hit/miss scenarios, TTL expiration, and manual invalidation
- Test transient failure handling (5xx, timeouts, connection errors) with retry logic
- Test service token header injection and empty key fast-path validation
- Test startup validation requiring
* test: add autouse fixture to restore config state after startup validation tests
Ensures test isolation for config-dependent integration tests
* feat: skip user_id resolution in non-remote-hosted mode
Prevents unnecessary API key validation when not in remote-hosted mode
* test: add missing mock attributes to instance routing tests
- Add client_id to test context mock in set_active_instance test
- Add get_state mock to context in global instance routing test
* Fix broken telemetry test
* Add comprehensive API key authentication documentation
- Add user guide covering configuration, setup, and troubleshooting
- Add architecture reference documenting internal design and request flows
* Add remote-hosted mode and API key authentication documentation to server README
* Update reference doc for Docker Hub
* Specify exception being caught
* Ensure caplog handler cleanup in telemetry queue worker test
* Use NoUnitySessionError instead of RuntimeError in session isolation test
* Remove unusued monkeypatch arg
* Use more obviously fake API keys
* Reject connections when ApiKeyService is not initialized in remote-hosted mode
- Validate that user_id is present after successful key validation
- Expand transient error detection to include timeout and service errors
- Use consistent 1013 status code for retryable auth failures
* Accept "on" for UNITY_MCP_HTTP_REMOTE_HOSTED env var
Consistent with repo
* Invalidate cached login URL when HTTP base URL changes
* Pass API key as parameter instead of reading from EditorPrefs in RegisterWithCapturedValues
* Cache API key in field instead of reading from EditorPrefs on each reconnection
* Align markdown table formatting in remote server auth documentation
* Minor fixes
* security: Sanitize API key values in shell commands and fix minor issues
Add SanitizeShellHeaderValue() method to escape special shell characters (", \, `, $, !) in API keys before including them in shell command arguments. Apply sanitization to all three locations where API keys are embedded in shell commands (two in RegisterWithCapturedValues, one in GetManualInstructions).
Also fix deprecated passwordCharacter property (now maskChar) and improve exception logging in _resolve_user_id_from_request
* Consolidate duplicate instance selection error messages into InstanceSelectionRequiredError class
Add InstanceSelectionRequiredError exception class with centralized error messages (_SELECTION_REQUIRED and _MULTIPLE_INSTANCES). Replace 4 duplicate RuntimeError raises with new exception type. Update tests to catch InstanceSelectionRequiredError instead of RuntimeError.
* Replace hardcoded "X-API-Key" strings with AuthConstants.ApiKeyHeader constant across C# and Python codebases
Add AuthConstants class in C# and API_KEY_HEADER constant in Python to centralize the API key header name definition. Update all 8 locations where "X-API-Key" was hardcoded (4 in C#, 4 in Python) to use the new constants instead.
* Fix imports
* Filter session listing by user_id in all code paths to prevent cross-user session access
Remove conditional logic that only filtered sessions by user_id in remote-hosted mode. Now all session listings are filtered by user_id regardless of hosting mode, ensuring users can only see and interact with their own sessions.
* Consolidate get_session_id_by_hash methods into single method with optional user_id parameter
Merge get_session_id_by_hash and get_session_id_by_user_hash into a single method that accepts an optional user_id parameter. Update all call sites to use the unified method signature with user_id as the second parameter. Update tests and documentation to reflect the simplified API.
* Add environment variable support for project-scoped-tools flag [skip ci]
Support UNITY_MCP_PROJECT_SCOPED_TOOLS environment variable as alternative to --project-scoped-tools command line flag. Accept "true", "1", "yes", or "on" as truthy values (case-insensitive). Update help text to document the environment variable option.
* Fix Python tests
* Update validation logic to only require API key validation URL when both http_remote_hosted is enabled AND transport mode is "http", preventing false validation errors in stdio mode.
* Update Server/src/main.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
* Refactor HTTP transport configuration to support separate local and remote URLs
Split HTTP transport into HttpLocal and HttpRemote modes with separate EditorPrefs storage (HttpBaseUrl and HttpRemoteBaseUrl).
Add HttpEndpointUtility methods to get/save local and remote URLs independently, and introduce IsRemoteScope() and GetCurrentServerTransport() helpers to centralize 3-way transport determination (Stdio/Http/HttpRemote). Update all client configuration code to distinguish between local and remote HTTP
* Only include API key headers in HTTP/WebSocket configuration when in remote-hosted mode
Update all locations where API key headers are added to HTTP/WebSocket configurations to check HttpEndpointUtility.IsRemoteScope() or serverTransport == HttpRemote before including the API key. This prevents local HTTP mode from unnecessarily including API key headers in shell commands, config JSON, and WebSocket connections.
* Hide Manual Server Launch foldout when not in HTTP Local mode
* Fix failing test
* Improve error messaging and API key validation for HTTP Remote transport
Add detailed error messages to WebSocket connection failures that guide users to check server URL, server status, and API key validity. Store error state in TransportState for propagation to UI. Disable "Start Session" button when HTTP Remote mode is selected without an API key, with tooltip explaining requirement. Display error dialog on connection failure with specific error message from transport state. Update connection
* Add missing .meta file
* Store transport mode in ServerConfig instead of environment variable
* Add autouse fixture to restore global config state between tests
Add restore_global_config fixture in conftest.py that automatically saves and restores global config attributes and UNITY_MCP_TRANSPORT environment variable between tests. Update integration tests to use monkeypatch.setattr on config.transport_mode instead of monkeypatch.setenv to prevent test pollution and ensure clean state isolation.
* Fix startup
* Replace _current_transport() calls with direct config.transport_mode access
* Minor cleanup
* Add integration tests for HTTP transport authentication behavior
Verify that HTTP local mode allows requests without user_id while HTTP remote-hosted mode rejects them with auth_required error.
* Add smoke tests for transport routing paths across HTTP local, HTTP remote, and stdio modes
Verify that HTTP local routes through PluginHub without user_id, HTTP remote routes through PluginHub with user_id, and stdio calls legacy send function with instance_id. Each test uses monkeypatch to configure transport mode and mock appropriate transport layer functions.
---------
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
main
parent
8ee9700327
commit
664a43b76c
|
|
@ -0,0 +1,11 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 14a4b9a7f749248d496466c2a3a53e56
|
||||||
|
MonoImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
serializedVersion: 2
|
||||||
|
defaultReferences: []
|
||||||
|
executionOrder: 0
|
||||||
|
icon: {instanceID: 0}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
||||||
|
|
@ -155,9 +155,19 @@ namespace MCPForUnity.Editor.Clients
|
||||||
client.configuredTransport = Models.ConfiguredTransport.Stdio;
|
client.configuredTransport = Models.ConfiguredTransport.Stdio;
|
||||||
}
|
}
|
||||||
else if (!string.IsNullOrEmpty(configuredUrl))
|
else if (!string.IsNullOrEmpty(configuredUrl))
|
||||||
|
{
|
||||||
|
// Distinguish HTTP Local from HTTP Remote by matching against both URLs
|
||||||
|
string localRpcUrl = HttpEndpointUtility.GetLocalMcpRpcUrl();
|
||||||
|
string remoteRpcUrl = HttpEndpointUtility.GetRemoteMcpRpcUrl();
|
||||||
|
if (!string.IsNullOrEmpty(remoteRpcUrl) && UrlsEqual(configuredUrl, remoteRpcUrl))
|
||||||
|
{
|
||||||
|
client.configuredTransport = Models.ConfiguredTransport.HttpRemote;
|
||||||
|
}
|
||||||
|
else
|
||||||
{
|
{
|
||||||
client.configuredTransport = Models.ConfiguredTransport.Http;
|
client.configuredTransport = Models.ConfiguredTransport.Http;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
client.configuredTransport = Models.ConfiguredTransport.Unknown;
|
client.configuredTransport = Models.ConfiguredTransport.Unknown;
|
||||||
|
|
@ -173,6 +183,7 @@ namespace MCPForUnity.Editor.Clients
|
||||||
}
|
}
|
||||||
else if (!string.IsNullOrEmpty(configuredUrl))
|
else if (!string.IsNullOrEmpty(configuredUrl))
|
||||||
{
|
{
|
||||||
|
// Match against the active scope's URL
|
||||||
string expectedUrl = HttpEndpointUtility.GetMcpRpcUrl();
|
string expectedUrl = HttpEndpointUtility.GetMcpRpcUrl();
|
||||||
matches = UrlsEqual(configuredUrl, expectedUrl);
|
matches = UrlsEqual(configuredUrl, expectedUrl);
|
||||||
}
|
}
|
||||||
|
|
@ -189,9 +200,7 @@ namespace MCPForUnity.Editor.Clients
|
||||||
if (result == "Configured successfully")
|
if (result == "Configured successfully")
|
||||||
{
|
{
|
||||||
client.SetStatus(McpStatus.Configured);
|
client.SetStatus(McpStatus.Configured);
|
||||||
// Update transport after rewrite based on current server setting
|
client.configuredTransport = HttpEndpointUtility.GetCurrentServerTransport();
|
||||||
bool useHttp = EditorConfigurationCache.Instance.UseHttpTransport;
|
|
||||||
client.configuredTransport = useHttp ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio;
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
|
@ -220,9 +229,7 @@ namespace MCPForUnity.Editor.Clients
|
||||||
if (result == "Configured successfully")
|
if (result == "Configured successfully")
|
||||||
{
|
{
|
||||||
client.SetStatus(McpStatus.Configured);
|
client.SetStatus(McpStatus.Configured);
|
||||||
// Set transport based on current server setting
|
client.configuredTransport = HttpEndpointUtility.GetCurrentServerTransport();
|
||||||
bool useHttp = EditorConfigurationCache.Instance.UseHttpTransport;
|
|
||||||
client.configuredTransport = useHttp ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio;
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
|
@ -271,9 +278,18 @@ namespace MCPForUnity.Editor.Clients
|
||||||
{
|
{
|
||||||
// Determine and set the configured transport type
|
// Determine and set the configured transport type
|
||||||
if (!string.IsNullOrEmpty(url))
|
if (!string.IsNullOrEmpty(url))
|
||||||
|
{
|
||||||
|
// Distinguish HTTP Local from HTTP Remote
|
||||||
|
string remoteRpcUrl = HttpEndpointUtility.GetRemoteMcpRpcUrl();
|
||||||
|
if (!string.IsNullOrEmpty(remoteRpcUrl) && UrlsEqual(url, remoteRpcUrl))
|
||||||
|
{
|
||||||
|
client.configuredTransport = Models.ConfiguredTransport.HttpRemote;
|
||||||
|
}
|
||||||
|
else
|
||||||
{
|
{
|
||||||
client.configuredTransport = Models.ConfiguredTransport.Http;
|
client.configuredTransport = Models.ConfiguredTransport.Http;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
else if (args != null && args.Length > 0)
|
else if (args != null && args.Length > 0)
|
||||||
{
|
{
|
||||||
client.configuredTransport = Models.ConfiguredTransport.Stdio;
|
client.configuredTransport = Models.ConfiguredTransport.Stdio;
|
||||||
|
|
@ -286,6 +302,7 @@ namespace MCPForUnity.Editor.Clients
|
||||||
bool matches = false;
|
bool matches = false;
|
||||||
if (!string.IsNullOrEmpty(url))
|
if (!string.IsNullOrEmpty(url))
|
||||||
{
|
{
|
||||||
|
// Match against the active scope's URL
|
||||||
matches = UrlsEqual(url, HttpEndpointUtility.GetMcpRpcUrl());
|
matches = UrlsEqual(url, HttpEndpointUtility.GetMcpRpcUrl());
|
||||||
}
|
}
|
||||||
else if (args != null && args.Length > 0)
|
else if (args != null && args.Length > 0)
|
||||||
|
|
@ -313,9 +330,7 @@ namespace MCPForUnity.Editor.Clients
|
||||||
if (result == "Configured successfully")
|
if (result == "Configured successfully")
|
||||||
{
|
{
|
||||||
client.SetStatus(McpStatus.Configured);
|
client.SetStatus(McpStatus.Configured);
|
||||||
// Update transport after rewrite based on current server setting
|
client.configuredTransport = HttpEndpointUtility.GetCurrentServerTransport();
|
||||||
bool useHttp = EditorConfigurationCache.Instance.UseHttpTransport;
|
|
||||||
client.configuredTransport = useHttp ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio;
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
|
@ -344,9 +359,7 @@ namespace MCPForUnity.Editor.Clients
|
||||||
if (result == "Configured successfully")
|
if (result == "Configured successfully")
|
||||||
{
|
{
|
||||||
client.SetStatus(McpStatus.Configured);
|
client.SetStatus(McpStatus.Configured);
|
||||||
// Set transport based on current server setting
|
client.configuredTransport = HttpEndpointUtility.GetCurrentServerTransport();
|
||||||
bool useHttp = EditorConfigurationCache.Instance.UseHttpTransport;
|
|
||||||
client.configuredTransport = useHttp ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio;
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
|
@ -468,9 +481,13 @@ namespace MCPForUnity.Editor.Clients
|
||||||
bool registeredWithStdio = getStdout.Contains("Type: stdio", StringComparison.OrdinalIgnoreCase);
|
bool registeredWithStdio = getStdout.Contains("Type: stdio", StringComparison.OrdinalIgnoreCase);
|
||||||
|
|
||||||
// Set the configured transport based on what we detected
|
// Set the configured transport based on what we detected
|
||||||
|
// For HTTP, we can't distinguish local/remote from CLI output alone,
|
||||||
|
// so infer from the current scope setting when HTTP is detected.
|
||||||
if (registeredWithHttp)
|
if (registeredWithHttp)
|
||||||
{
|
{
|
||||||
client.configuredTransport = Models.ConfiguredTransport.Http;
|
client.configuredTransport = HttpEndpointUtility.IsRemoteScope()
|
||||||
|
? Models.ConfiguredTransport.HttpRemote
|
||||||
|
: Models.ConfiguredTransport.Http;
|
||||||
}
|
}
|
||||||
else if (registeredWithStdio)
|
else if (registeredWithStdio)
|
||||||
{
|
{
|
||||||
|
|
@ -481,7 +498,7 @@ namespace MCPForUnity.Editor.Clients
|
||||||
client.configuredTransport = Models.ConfiguredTransport.Unknown;
|
client.configuredTransport = Models.ConfiguredTransport.Unknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for transport mismatch
|
// Check for transport mismatch (3-way: Stdio, Http, HttpRemote)
|
||||||
bool hasTransportMismatch = (currentUseHttp && registeredWithStdio) || (!currentUseHttp && registeredWithHttp);
|
bool hasTransportMismatch = (currentUseHttp && registeredWithStdio) || (!currentUseHttp && registeredWithHttp);
|
||||||
|
|
||||||
// For stdio transport, also check package version
|
// For stdio transport, also check package version
|
||||||
|
|
@ -575,7 +592,9 @@ namespace MCPForUnity.Editor.Clients
|
||||||
public void ConfigureWithCapturedValues(
|
public void ConfigureWithCapturedValues(
|
||||||
string projectDir, string claudePath, string pathPrepend,
|
string projectDir, string claudePath, string pathPrepend,
|
||||||
bool useHttpTransport, string httpUrl,
|
bool useHttpTransport, string httpUrl,
|
||||||
string uvxPath, string gitUrl, string packageName, bool shouldForceRefresh)
|
string uvxPath, string gitUrl, string packageName, bool shouldForceRefresh,
|
||||||
|
string apiKey,
|
||||||
|
Models.ConfiguredTransport serverTransport)
|
||||||
{
|
{
|
||||||
if (client.status == McpStatus.Configured)
|
if (client.status == McpStatus.Configured)
|
||||||
{
|
{
|
||||||
|
|
@ -584,7 +603,8 @@ namespace MCPForUnity.Editor.Clients
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
RegisterWithCapturedValues(projectDir, claudePath, pathPrepend,
|
RegisterWithCapturedValues(projectDir, claudePath, pathPrepend,
|
||||||
useHttpTransport, httpUrl, uvxPath, gitUrl, packageName, shouldForceRefresh);
|
useHttpTransport, httpUrl, uvxPath, gitUrl, packageName, shouldForceRefresh,
|
||||||
|
apiKey, serverTransport);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -594,7 +614,9 @@ namespace MCPForUnity.Editor.Clients
|
||||||
private void RegisterWithCapturedValues(
|
private void RegisterWithCapturedValues(
|
||||||
string projectDir, string claudePath, string pathPrepend,
|
string projectDir, string claudePath, string pathPrepend,
|
||||||
bool useHttpTransport, string httpUrl,
|
bool useHttpTransport, string httpUrl,
|
||||||
string uvxPath, string gitUrl, string packageName, bool shouldForceRefresh)
|
string uvxPath, string gitUrl, string packageName, bool shouldForceRefresh,
|
||||||
|
string apiKey,
|
||||||
|
Models.ConfiguredTransport serverTransport)
|
||||||
{
|
{
|
||||||
if (string.IsNullOrEmpty(claudePath))
|
if (string.IsNullOrEmpty(claudePath))
|
||||||
{
|
{
|
||||||
|
|
@ -603,9 +625,18 @@ namespace MCPForUnity.Editor.Clients
|
||||||
|
|
||||||
string args;
|
string args;
|
||||||
if (useHttpTransport)
|
if (useHttpTransport)
|
||||||
|
{
|
||||||
|
// Only include API key header for remote-hosted mode
|
||||||
|
if (serverTransport == Models.ConfiguredTransport.HttpRemote && !string.IsNullOrEmpty(apiKey))
|
||||||
|
{
|
||||||
|
string safeKey = SanitizeShellHeaderValue(apiKey);
|
||||||
|
args = $"mcp add --transport http UnityMCP {httpUrl} --header \"{AuthConstants.ApiKeyHeader}: {safeKey}\"";
|
||||||
|
}
|
||||||
|
else
|
||||||
{
|
{
|
||||||
args = $"mcp add --transport http UnityMCP {httpUrl}";
|
args = $"mcp add --transport http UnityMCP {httpUrl}";
|
||||||
}
|
}
|
||||||
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
// Note: --reinstall is not supported by uvx, use --no-cache --refresh instead
|
// Note: --reinstall is not supported by uvx, use --no-cache --refresh instead
|
||||||
|
|
@ -626,7 +657,7 @@ namespace MCPForUnity.Editor.Clients
|
||||||
|
|
||||||
McpLog.Info($"Successfully registered with Claude Code using {(useHttpTransport ? "HTTP" : "stdio")} transport.");
|
McpLog.Info($"Successfully registered with Claude Code using {(useHttpTransport ? "HTTP" : "stdio")} transport.");
|
||||||
client.SetStatus(McpStatus.Configured);
|
client.SetStatus(McpStatus.Configured);
|
||||||
client.configuredTransport = useHttpTransport ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio;
|
client.configuredTransport = serverTransport;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|
@ -664,8 +695,25 @@ namespace MCPForUnity.Editor.Clients
|
||||||
if (useHttpTransport)
|
if (useHttpTransport)
|
||||||
{
|
{
|
||||||
string httpUrl = HttpEndpointUtility.GetMcpRpcUrl();
|
string httpUrl = HttpEndpointUtility.GetMcpRpcUrl();
|
||||||
|
// Only include API key header for remote-hosted mode
|
||||||
|
if (HttpEndpointUtility.IsRemoteScope())
|
||||||
|
{
|
||||||
|
string apiKey = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
|
||||||
|
if (!string.IsNullOrEmpty(apiKey))
|
||||||
|
{
|
||||||
|
string safeKey = SanitizeShellHeaderValue(apiKey);
|
||||||
|
args = $"mcp add --transport http UnityMCP {httpUrl} --header \"{AuthConstants.ApiKeyHeader}: {safeKey}\"";
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
args = $"mcp add --transport http UnityMCP {httpUrl}";
|
args = $"mcp add --transport http UnityMCP {httpUrl}";
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
args = $"mcp add --transport http UnityMCP {httpUrl}";
|
||||||
|
}
|
||||||
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
var (uvxPath, gitUrl, packageName) = AssetPathUtility.GetUvxCommandParts();
|
var (uvxPath, gitUrl, packageName) = AssetPathUtility.GetUvxCommandParts();
|
||||||
|
|
@ -715,7 +763,7 @@ namespace MCPForUnity.Editor.Clients
|
||||||
// Set status to Configured immediately after successful registration
|
// Set status to Configured immediately after successful registration
|
||||||
// The UI will trigger an async verification check separately to avoid blocking
|
// The UI will trigger an async verification check separately to avoid blocking
|
||||||
client.SetStatus(McpStatus.Configured);
|
client.SetStatus(McpStatus.Configured);
|
||||||
client.configuredTransport = useHttpTransport ? Models.ConfiguredTransport.Http : Models.ConfiguredTransport.Stdio;
|
client.configuredTransport = HttpEndpointUtility.GetCurrentServerTransport();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void Unregister()
|
private void Unregister()
|
||||||
|
|
@ -757,8 +805,15 @@ namespace MCPForUnity.Editor.Clients
|
||||||
if (useHttpTransport)
|
if (useHttpTransport)
|
||||||
{
|
{
|
||||||
string httpUrl = HttpEndpointUtility.GetMcpRpcUrl();
|
string httpUrl = HttpEndpointUtility.GetMcpRpcUrl();
|
||||||
|
// Only include API key header for remote-hosted mode
|
||||||
|
string headerArg = "";
|
||||||
|
if (HttpEndpointUtility.IsRemoteScope())
|
||||||
|
{
|
||||||
|
string apiKey = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
|
||||||
|
headerArg = !string.IsNullOrEmpty(apiKey) ? $" --header \"{AuthConstants.ApiKeyHeader}: {SanitizeShellHeaderValue(apiKey)}\"" : "";
|
||||||
|
}
|
||||||
return "# Register the MCP server with Claude Code:\n" +
|
return "# Register the MCP server with Claude Code:\n" +
|
||||||
$"claude mcp add --transport http UnityMCP {httpUrl}\n\n" +
|
$"claude mcp add --transport http UnityMCP {httpUrl}{headerArg}\n\n" +
|
||||||
"# Unregister the MCP server:\n" +
|
"# Unregister the MCP server:\n" +
|
||||||
"claude mcp remove UnityMCP\n\n" +
|
"claude mcp remove UnityMCP\n\n" +
|
||||||
"# List registered servers:\n" +
|
"# List registered servers:\n" +
|
||||||
|
|
@ -790,6 +845,37 @@ namespace MCPForUnity.Editor.Clients
|
||||||
"Restart Claude Code"
|
"Restart Claude Code"
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Sanitizes a value for safe inclusion inside a double-quoted shell argument.
|
||||||
|
/// Escapes characters that are special within double quotes (", \, `, $, !)
|
||||||
|
/// to prevent shell injection or argument splitting.
|
||||||
|
/// </summary>
|
||||||
|
private static string SanitizeShellHeaderValue(string value)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrEmpty(value))
|
||||||
|
return value;
|
||||||
|
|
||||||
|
var sb = new System.Text.StringBuilder(value.Length);
|
||||||
|
foreach (char c in value)
|
||||||
|
{
|
||||||
|
switch (c)
|
||||||
|
{
|
||||||
|
case '"':
|
||||||
|
case '\\':
|
||||||
|
case '`':
|
||||||
|
case '$':
|
||||||
|
case '!':
|
||||||
|
sb.Append('\\');
|
||||||
|
sb.Append(c);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
sb.Append(c);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.ToString();
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Extracts the package source (--from argument value) from claude mcp get output.
|
/// Extracts the package source (--from argument value) from claude mcp get output.
|
||||||
/// The output format includes args like: --from "mcpforunityserver==9.0.1"
|
/// The output format includes args like: --from "mcpforunityserver==9.0.1"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
namespace MCPForUnity.Editor.Constants
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Protocol-level constants for API key authentication.
|
||||||
|
/// </summary>
|
||||||
|
internal static class AuthConstants
|
||||||
|
{
|
||||||
|
internal const string ApiKeyHeader = "X-API-Key";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 96844bc39e9a94cf18b18f8127f3854f
|
||||||
|
MonoImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
serializedVersion: 2
|
||||||
|
defaultReferences: []
|
||||||
|
executionOrder: 0
|
||||||
|
icon: {instanceID: 0}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
||||||
|
|
@ -24,6 +24,7 @@ namespace MCPForUnity.Editor.Constants
|
||||||
internal const string ClaudeCliPathOverride = "MCPForUnity.ClaudeCliPath";
|
internal const string ClaudeCliPathOverride = "MCPForUnity.ClaudeCliPath";
|
||||||
|
|
||||||
internal const string HttpBaseUrl = "MCPForUnity.HttpUrl";
|
internal const string HttpBaseUrl = "MCPForUnity.HttpUrl";
|
||||||
|
internal const string HttpRemoteBaseUrl = "MCPForUnity.HttpRemoteUrl";
|
||||||
internal const string SessionId = "MCPForUnity.SessionId";
|
internal const string SessionId = "MCPForUnity.SessionId";
|
||||||
internal const string WebSocketUrlOverride = "MCPForUnity.WebSocketUrl";
|
internal const string WebSocketUrlOverride = "MCPForUnity.WebSocketUrl";
|
||||||
internal const string GitUrlOverride = "MCPForUnity.GitUrlOverride";
|
internal const string GitUrlOverride = "MCPForUnity.GitUrlOverride";
|
||||||
|
|
@ -55,5 +56,7 @@ namespace MCPForUnity.Editor.Constants
|
||||||
|
|
||||||
internal const string TelemetryDisabled = "MCPForUnity.TelemetryDisabled";
|
internal const string TelemetryDisabled = "MCPForUnity.TelemetryDisabled";
|
||||||
internal const string CustomerUuid = "MCPForUnity.CustomerUUID";
|
internal const string CustomerUuid = "MCPForUnity.CustomerUUID";
|
||||||
|
|
||||||
|
internal const string ApiKey = "MCPForUnity.ApiKey";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,26 @@ namespace MCPForUnity.Editor.Helpers
|
||||||
if (unity["command"] != null) unity.Remove("command");
|
if (unity["command"] != null) unity.Remove("command");
|
||||||
if (unity["args"] != null) unity.Remove("args");
|
if (unity["args"] != null) unity.Remove("args");
|
||||||
|
|
||||||
|
// Only include API key header for remote-hosted mode
|
||||||
|
if (HttpEndpointUtility.IsRemoteScope())
|
||||||
|
{
|
||||||
|
string apiKey = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
|
||||||
|
if (!string.IsNullOrEmpty(apiKey))
|
||||||
|
{
|
||||||
|
var headers = new JObject { [AuthConstants.ApiKeyHeader] = apiKey };
|
||||||
|
unity["headers"] = headers;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (unity["headers"] != null) unity.Remove("headers");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Local HTTP doesn't use API keys; remove any stale headers
|
||||||
|
if (unity["headers"] != null) unity.Remove("headers");
|
||||||
|
}
|
||||||
|
|
||||||
if (isVSCode)
|
if (isVSCode)
|
||||||
{
|
{
|
||||||
unity["type"] = "http";
|
unity["type"] = "http";
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
using System;
|
using System;
|
||||||
using MCPForUnity.Editor.Constants;
|
using MCPForUnity.Editor.Constants;
|
||||||
|
using MCPForUnity.Editor.Models;
|
||||||
|
using MCPForUnity.Editor.Services;
|
||||||
using UnityEditor;
|
using UnityEditor;
|
||||||
|
|
||||||
namespace MCPForUnity.Editor.Helpers
|
namespace MCPForUnity.Editor.Helpers
|
||||||
|
|
@ -8,38 +10,113 @@ namespace MCPForUnity.Editor.Helpers
|
||||||
/// Helper methods for managing HTTP endpoint URLs used by the MCP bridge.
|
/// Helper methods for managing HTTP endpoint URLs used by the MCP bridge.
|
||||||
/// Ensures the stored value is always the base URL (without trailing path),
|
/// Ensures the stored value is always the base URL (without trailing path),
|
||||||
/// and provides convenience accessors for specific endpoints.
|
/// and provides convenience accessors for specific endpoints.
|
||||||
|
///
|
||||||
|
/// HTTP Local and HTTP Remote use separate EditorPrefs keys so that switching
|
||||||
|
/// between scopes does not overwrite the other scope's URL.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public static class HttpEndpointUtility
|
public static class HttpEndpointUtility
|
||||||
{
|
{
|
||||||
private const string PrefKey = EditorPrefKeys.HttpBaseUrl;
|
private const string LocalPrefKey = EditorPrefKeys.HttpBaseUrl;
|
||||||
private const string DefaultBaseUrl = "http://localhost:8080";
|
private const string RemotePrefKey = EditorPrefKeys.HttpRemoteBaseUrl;
|
||||||
|
private const string DefaultLocalBaseUrl = "http://localhost:8080";
|
||||||
|
private const string DefaultRemoteBaseUrl = "";
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Returns the normalized base URL currently stored in EditorPrefs.
|
/// Returns the normalized base URL for the currently active HTTP scope.
|
||||||
|
/// If the scope is "remote", returns the remote URL; otherwise returns the local URL.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public static string GetBaseUrl()
|
public static string GetBaseUrl()
|
||||||
{
|
{
|
||||||
string stored = EditorPrefs.GetString(PrefKey, DefaultBaseUrl);
|
return IsRemoteScope() ? GetRemoteBaseUrl() : GetLocalBaseUrl();
|
||||||
return NormalizeBaseUrl(stored);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Saves a user-provided URL after normalizing it to a base form.
|
/// Saves a user-provided URL to the currently active HTTP scope's pref.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public static void SaveBaseUrl(string userValue)
|
public static void SaveBaseUrl(string userValue)
|
||||||
{
|
{
|
||||||
string normalized = NormalizeBaseUrl(userValue);
|
if (IsRemoteScope())
|
||||||
EditorPrefs.SetString(PrefKey, normalized);
|
{
|
||||||
|
SaveRemoteBaseUrl(userValue);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
SaveLocalBaseUrl(userValue);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Builds the JSON-RPC endpoint used by FastMCP clients (base + /mcp).
|
/// Returns the normalized local HTTP base URL (always reads local pref).
|
||||||
|
/// </summary>
|
||||||
|
public static string GetLocalBaseUrl()
|
||||||
|
{
|
||||||
|
string stored = EditorPrefs.GetString(LocalPrefKey, DefaultLocalBaseUrl);
|
||||||
|
return NormalizeBaseUrl(stored, DefaultLocalBaseUrl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Saves a user-provided URL to the local HTTP pref.
|
||||||
|
/// </summary>
|
||||||
|
public static void SaveLocalBaseUrl(string userValue)
|
||||||
|
{
|
||||||
|
string normalized = NormalizeBaseUrl(userValue, DefaultLocalBaseUrl);
|
||||||
|
EditorPrefs.SetString(LocalPrefKey, normalized);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Returns the normalized remote HTTP base URL (always reads remote pref).
|
||||||
|
/// Returns empty string if no remote URL is configured.
|
||||||
|
/// </summary>
|
||||||
|
public static string GetRemoteBaseUrl()
|
||||||
|
{
|
||||||
|
string stored = EditorPrefs.GetString(RemotePrefKey, DefaultRemoteBaseUrl);
|
||||||
|
if (string.IsNullOrWhiteSpace(stored))
|
||||||
|
{
|
||||||
|
return DefaultRemoteBaseUrl;
|
||||||
|
}
|
||||||
|
return NormalizeBaseUrl(stored, DefaultRemoteBaseUrl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Saves a user-provided URL to the remote HTTP pref.
|
||||||
|
/// </summary>
|
||||||
|
public static void SaveRemoteBaseUrl(string userValue)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrWhiteSpace(userValue))
|
||||||
|
{
|
||||||
|
EditorPrefs.SetString(RemotePrefKey, DefaultRemoteBaseUrl);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
string normalized = NormalizeBaseUrl(userValue, DefaultRemoteBaseUrl);
|
||||||
|
EditorPrefs.SetString(RemotePrefKey, normalized);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Builds the JSON-RPC endpoint for the currently active scope (base + /mcp).
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public static string GetMcpRpcUrl()
|
public static string GetMcpRpcUrl()
|
||||||
{
|
{
|
||||||
return AppendPathSegment(GetBaseUrl(), "mcp");
|
return AppendPathSegment(GetBaseUrl(), "mcp");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Builds the local JSON-RPC endpoint (local base + /mcp).
|
||||||
|
/// </summary>
|
||||||
|
public static string GetLocalMcpRpcUrl()
|
||||||
|
{
|
||||||
|
return AppendPathSegment(GetLocalBaseUrl(), "mcp");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Builds the remote JSON-RPC endpoint (remote base + /mcp).
|
||||||
|
/// Returns empty string if no remote URL is configured.
|
||||||
|
/// </summary>
|
||||||
|
public static string GetRemoteMcpRpcUrl()
|
||||||
|
{
|
||||||
|
string remoteBase = GetRemoteBaseUrl();
|
||||||
|
return string.IsNullOrEmpty(remoteBase) ? string.Empty : AppendPathSegment(remoteBase, "mcp");
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Builds the endpoint used when POSTing custom-tool registration payloads.
|
/// Builds the endpoint used when POSTing custom-tool registration payloads.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
|
@ -48,14 +125,35 @@ namespace MCPForUnity.Editor.Helpers
|
||||||
return AppendPathSegment(GetBaseUrl(), "register-tools");
|
return AppendPathSegment(GetBaseUrl(), "register-tools");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Returns true if the active HTTP transport scope is "remote".
|
||||||
|
/// </summary>
|
||||||
|
public static bool IsRemoteScope()
|
||||||
|
{
|
||||||
|
string scope = EditorConfigurationCache.Instance.HttpTransportScope;
|
||||||
|
return string.Equals(scope, "remote", StringComparison.OrdinalIgnoreCase);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Returns the <see cref="ConfiguredTransport"/> that matches the current server-side
|
||||||
|
/// transport selection (Stdio, Http, or HttpRemote).
|
||||||
|
/// Centralises the 3-way determination so callers avoid duplicated logic.
|
||||||
|
/// </summary>
|
||||||
|
public static ConfiguredTransport GetCurrentServerTransport()
|
||||||
|
{
|
||||||
|
bool useHttp = EditorConfigurationCache.Instance.UseHttpTransport;
|
||||||
|
if (!useHttp) return ConfiguredTransport.Stdio;
|
||||||
|
return IsRemoteScope() ? ConfiguredTransport.HttpRemote : ConfiguredTransport.Http;
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Normalizes a URL so that we consistently store just the base (no trailing slash/path).
|
/// Normalizes a URL so that we consistently store just the base (no trailing slash/path).
|
||||||
/// </summary>
|
/// </summary>
|
||||||
private static string NormalizeBaseUrl(string value)
|
private static string NormalizeBaseUrl(string value, string defaultUrl)
|
||||||
{
|
{
|
||||||
if (string.IsNullOrWhiteSpace(value))
|
if (string.IsNullOrWhiteSpace(value))
|
||||||
{
|
{
|
||||||
return DefaultBaseUrl;
|
return defaultUrl;
|
||||||
}
|
}
|
||||||
|
|
||||||
string trimmed = value.Trim();
|
string trimmed = value.Trim();
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,8 @@ namespace MCPForUnity.Editor.Models
|
||||||
{
|
{
|
||||||
Unknown, // Could not determine transport type
|
Unknown, // Could not determine transport type
|
||||||
Stdio, // Client configured for stdio transport
|
Stdio, // Client configured for stdio transport
|
||||||
Http // Client configured for HTTP transport
|
Http, // Client configured for HTTP local transport
|
||||||
|
HttpRemote // Client configured for HTTP remote-hosted transport
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,7 @@ namespace MCPForUnity.Editor.Services
|
||||||
private string _uvxPathOverride;
|
private string _uvxPathOverride;
|
||||||
private string _gitUrlOverride;
|
private string _gitUrlOverride;
|
||||||
private string _httpBaseUrl;
|
private string _httpBaseUrl;
|
||||||
|
private string _httpRemoteBaseUrl;
|
||||||
private string _claudeCliPathOverride;
|
private string _claudeCliPathOverride;
|
||||||
private string _httpTransportScope;
|
private string _httpTransportScope;
|
||||||
private int _unitySocketPort;
|
private int _unitySocketPort;
|
||||||
|
|
@ -94,11 +95,17 @@ namespace MCPForUnity.Editor.Services
|
||||||
public string GitUrlOverride => _gitUrlOverride;
|
public string GitUrlOverride => _gitUrlOverride;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// HTTP base URL for the MCP server.
|
/// HTTP base URL for the local MCP server.
|
||||||
/// Default: empty string
|
/// Default: empty string
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public string HttpBaseUrl => _httpBaseUrl;
|
public string HttpBaseUrl => _httpBaseUrl;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// HTTP base URL for the remote-hosted MCP server.
|
||||||
|
/// Default: empty string
|
||||||
|
/// </summary>
|
||||||
|
public string HttpRemoteBaseUrl => _httpRemoteBaseUrl;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Custom path override for Claude CLI executable.
|
/// Custom path override for Claude CLI executable.
|
||||||
/// Default: empty string (auto-detect)
|
/// Default: empty string (auto-detect)
|
||||||
|
|
@ -135,6 +142,7 @@ namespace MCPForUnity.Editor.Services
|
||||||
_uvxPathOverride = EditorPrefs.GetString(EditorPrefKeys.UvxPathOverride, string.Empty);
|
_uvxPathOverride = EditorPrefs.GetString(EditorPrefKeys.UvxPathOverride, string.Empty);
|
||||||
_gitUrlOverride = EditorPrefs.GetString(EditorPrefKeys.GitUrlOverride, string.Empty);
|
_gitUrlOverride = EditorPrefs.GetString(EditorPrefKeys.GitUrlOverride, string.Empty);
|
||||||
_httpBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpBaseUrl, string.Empty);
|
_httpBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpBaseUrl, string.Empty);
|
||||||
|
_httpRemoteBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpRemoteBaseUrl, string.Empty);
|
||||||
_claudeCliPathOverride = EditorPrefs.GetString(EditorPrefKeys.ClaudeCliPathOverride, string.Empty);
|
_claudeCliPathOverride = EditorPrefs.GetString(EditorPrefKeys.ClaudeCliPathOverride, string.Empty);
|
||||||
_httpTransportScope = EditorPrefs.GetString(EditorPrefKeys.HttpTransportScope, string.Empty);
|
_httpTransportScope = EditorPrefs.GetString(EditorPrefKeys.HttpTransportScope, string.Empty);
|
||||||
_unitySocketPort = EditorPrefs.GetInt(EditorPrefKeys.UnitySocketPort, 0);
|
_unitySocketPort = EditorPrefs.GetInt(EditorPrefKeys.UnitySocketPort, 0);
|
||||||
|
|
@ -234,6 +242,20 @@ namespace MCPForUnity.Editor.Services
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Set HttpRemoteBaseUrl and update cache + EditorPrefs atomically.
|
||||||
|
/// </summary>
|
||||||
|
public void SetHttpRemoteBaseUrl(string value)
|
||||||
|
{
|
||||||
|
value = value ?? string.Empty;
|
||||||
|
if (_httpRemoteBaseUrl != value)
|
||||||
|
{
|
||||||
|
_httpRemoteBaseUrl = value;
|
||||||
|
EditorPrefs.SetString(EditorPrefKeys.HttpRemoteBaseUrl, value);
|
||||||
|
OnConfigurationChanged?.Invoke(nameof(HttpRemoteBaseUrl));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Set ClaudeCliPathOverride and update cache + EditorPrefs atomically.
|
/// Set ClaudeCliPathOverride and update cache + EditorPrefs atomically.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
|
@ -304,6 +326,9 @@ namespace MCPForUnity.Editor.Services
|
||||||
case nameof(HttpBaseUrl):
|
case nameof(HttpBaseUrl):
|
||||||
_httpBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpBaseUrl, string.Empty);
|
_httpBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpBaseUrl, string.Empty);
|
||||||
break;
|
break;
|
||||||
|
case nameof(HttpRemoteBaseUrl):
|
||||||
|
_httpRemoteBaseUrl = EditorPrefs.GetString(EditorPrefKeys.HttpRemoteBaseUrl, string.Empty);
|
||||||
|
break;
|
||||||
case nameof(ClaudeCliPathOverride):
|
case nameof(ClaudeCliPathOverride):
|
||||||
_claudeCliPathOverride = EditorPrefs.GetString(EditorPrefKeys.ClaudeCliPathOverride, string.Empty);
|
_claudeCliPathOverride = EditorPrefs.GetString(EditorPrefKeys.ClaudeCliPathOverride, string.Empty);
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ namespace MCPForUnity.Editor.Services.Server
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
string httpUrl = HttpEndpointUtility.GetBaseUrl();
|
string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
|
||||||
if (!IsLocalUrl(httpUrl))
|
if (!IsLocalUrl(httpUrl))
|
||||||
{
|
{
|
||||||
error = $"The configured URL ({httpUrl}) is not a local address. Local server launch only works for localhost.";
|
error = $"The configured URL ({httpUrl}) is not a local address. Local server launch only works for localhost.";
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
using System;
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
using System.IO;
|
using System.IO;
|
||||||
using System.Linq;
|
using System.Linq;
|
||||||
using System.Collections.Generic;
|
|
||||||
using System.Net.Sockets;
|
using System.Net.Sockets;
|
||||||
using MCPForUnity.Editor.Constants;
|
using MCPForUnity.Editor.Constants;
|
||||||
using MCPForUnity.Editor.Helpers;
|
using MCPForUnity.Editor.Helpers;
|
||||||
|
|
@ -158,7 +158,7 @@ namespace MCPForUnity.Editor.Services
|
||||||
|
|
||||||
if (success)
|
if (success)
|
||||||
{
|
{
|
||||||
McpLog.Debug($"uv cache cleared successfully: {stdout}");
|
McpLog.Info($"uv cache cleared successfully: {stdout}");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
string combinedOutput = string.Join(
|
string combinedOutput = string.Join(
|
||||||
|
|
@ -253,7 +253,7 @@ namespace MCPForUnity.Editor.Services
|
||||||
// If the port is still occupied, don't start and explain why (avoid confusing "refusing to stop" warnings).
|
// If the port is still occupied, don't start and explain why (avoid confusing "refusing to stop" warnings).
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
string httpUrl = HttpEndpointUtility.GetBaseUrl();
|
string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
|
||||||
if (Uri.TryCreate(httpUrl, UriKind.Absolute, out var uri) && uri.Port > 0)
|
if (Uri.TryCreate(httpUrl, UriKind.Absolute, out var uri) && uri.Port > 0)
|
||||||
{
|
{
|
||||||
var remaining = GetListeningProcessIdsForPort(uri.Port);
|
var remaining = GetListeningProcessIdsForPort(uri.Port);
|
||||||
|
|
@ -274,7 +274,7 @@ namespace MCPForUnity.Editor.Services
|
||||||
// Note: Dev mode cache-busting is handled by `uvx --no-cache --refresh` in the generated command.
|
// Note: Dev mode cache-busting is handled by `uvx --no-cache --refresh` in the generated command.
|
||||||
|
|
||||||
// Create a per-launch token + pidfile path so Stop can be deterministic without relying on port/PID heuristics.
|
// Create a per-launch token + pidfile path so Stop can be deterministic without relying on port/PID heuristics.
|
||||||
string baseUrlForPid = HttpEndpointUtility.GetBaseUrl();
|
string baseUrlForPid = HttpEndpointUtility.GetLocalBaseUrl();
|
||||||
Uri.TryCreate(baseUrlForPid, UriKind.Absolute, out var uriForPid);
|
Uri.TryCreate(baseUrlForPid, UriKind.Absolute, out var uriForPid);
|
||||||
int portForPid = uriForPid?.Port ?? 0;
|
int portForPid = uriForPid?.Port ?? 0;
|
||||||
string instanceToken = Guid.NewGuid().ToString("N");
|
string instanceToken = Guid.NewGuid().ToString("N");
|
||||||
|
|
@ -350,7 +350,7 @@ namespace MCPForUnity.Editor.Services
|
||||||
int port = 0;
|
int port = 0;
|
||||||
if (!TryGetPortFromPidFilePath(pidFilePath, out port) || port <= 0)
|
if (!TryGetPortFromPidFilePath(pidFilePath, out port) || port <= 0)
|
||||||
{
|
{
|
||||||
string baseUrl = HttpEndpointUtility.GetBaseUrl();
|
string baseUrl = HttpEndpointUtility.GetLocalBaseUrl();
|
||||||
if (IsLocalUrl(baseUrl)
|
if (IsLocalUrl(baseUrl)
|
||||||
&& Uri.TryCreate(baseUrl, UriKind.Absolute, out var uri)
|
&& Uri.TryCreate(baseUrl, UriKind.Absolute, out var uri)
|
||||||
&& uri.Port > 0)
|
&& uri.Port > 0)
|
||||||
|
|
@ -371,7 +371,7 @@ namespace MCPForUnity.Editor.Services
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
string httpUrl = HttpEndpointUtility.GetBaseUrl();
|
string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
|
||||||
if (!IsLocalUrl(httpUrl))
|
if (!IsLocalUrl(httpUrl))
|
||||||
{
|
{
|
||||||
return false;
|
return false;
|
||||||
|
|
@ -433,7 +433,7 @@ namespace MCPForUnity.Editor.Services
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
string httpUrl = HttpEndpointUtility.GetBaseUrl();
|
string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
|
||||||
if (!IsLocalUrl(httpUrl))
|
if (!IsLocalUrl(httpUrl))
|
||||||
{
|
{
|
||||||
return false;
|
return false;
|
||||||
|
|
@ -500,7 +500,7 @@ namespace MCPForUnity.Editor.Services
|
||||||
|
|
||||||
private bool StopLocalHttpServerInternal(bool quiet, int? portOverride = null, bool allowNonLocalUrl = false)
|
private bool StopLocalHttpServerInternal(bool quiet, int? portOverride = null, bool allowNonLocalUrl = false)
|
||||||
{
|
{
|
||||||
string httpUrl = HttpEndpointUtility.GetBaseUrl();
|
string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
|
||||||
if (!allowNonLocalUrl && !IsLocalUrl(httpUrl))
|
if (!allowNonLocalUrl && !IsLocalUrl(httpUrl))
|
||||||
{
|
{
|
||||||
if (!quiet)
|
if (!quiet)
|
||||||
|
|
@ -836,7 +836,7 @@ namespace MCPForUnity.Editor.Services
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public bool IsLocalUrl()
|
public bool IsLocalUrl()
|
||||||
{
|
{
|
||||||
string httpUrl = HttpEndpointUtility.GetBaseUrl();
|
string httpUrl = HttpEndpointUtility.GetLocalBaseUrl();
|
||||||
return IsLocalUrl(httpUrl);
|
return IsLocalUrl(httpUrl);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ namespace MCPForUnity.Editor.Services.Transport
|
||||||
{
|
{
|
||||||
McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}");
|
McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}");
|
||||||
}
|
}
|
||||||
UpdateState(mode, TransportState.Disconnected(client.TransportName, "Failed to start"));
|
UpdateState(mode, TransportState.Disconnected(client.TransportName, client.State?.Error ?? "Failed to start"));
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ using System.Net.WebSockets;
|
||||||
using System.Text;
|
using System.Text;
|
||||||
using System.Threading;
|
using System.Threading;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
|
using MCPForUnity.Editor.Constants;
|
||||||
using MCPForUnity.Editor.Helpers;
|
using MCPForUnity.Editor.Helpers;
|
||||||
using MCPForUnity.Editor.Services;
|
using MCPForUnity.Editor.Services;
|
||||||
using MCPForUnity.Editor.Services.Transport;
|
using MCPForUnity.Editor.Services.Transport;
|
||||||
|
|
@ -56,6 +57,7 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
|
||||||
private volatile bool _isConnected;
|
private volatile bool _isConnected;
|
||||||
private int _isReconnectingFlag;
|
private int _isReconnectingFlag;
|
||||||
private TransportState _state = TransportState.Disconnected(TransportDisplayName, "Transport not started");
|
private TransportState _state = TransportState.Disconnected(TransportDisplayName, "Transport not started");
|
||||||
|
private string _apiKey;
|
||||||
private bool _disposed;
|
private bool _disposed;
|
||||||
|
|
||||||
public WebSocketTransportClient(IToolDiscoveryService toolDiscoveryService = null)
|
public WebSocketTransportClient(IToolDiscoveryService toolDiscoveryService = null)
|
||||||
|
|
@ -80,6 +82,9 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
|
||||||
_projectName = ProjectIdentityUtility.GetProjectName();
|
_projectName = ProjectIdentityUtility.GetProjectName();
|
||||||
_projectHash = ProjectIdentityUtility.GetProjectHash();
|
_projectHash = ProjectIdentityUtility.GetProjectHash();
|
||||||
_unityVersion = Application.unityVersion;
|
_unityVersion = Application.unityVersion;
|
||||||
|
_apiKey = HttpEndpointUtility.IsRemoteScope()
|
||||||
|
? EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty)
|
||||||
|
: string.Empty;
|
||||||
|
|
||||||
// Get project root path (strip /Assets from dataPath) for focus nudging
|
// Get project root path (strip /Assets from dataPath) for focus nudging
|
||||||
string dataPath = Application.dataPath;
|
string dataPath = Application.dataPath;
|
||||||
|
|
@ -214,13 +219,21 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
|
||||||
_socket = new ClientWebSocket();
|
_socket = new ClientWebSocket();
|
||||||
_socket.Options.KeepAliveInterval = _socketKeepAliveInterval;
|
_socket.Options.KeepAliveInterval = _socketKeepAliveInterval;
|
||||||
|
|
||||||
|
// Add API key header if configured (for remote-hosted mode)
|
||||||
|
if (!string.IsNullOrEmpty(_apiKey))
|
||||||
|
{
|
||||||
|
_socket.Options.SetRequestHeader(AuthConstants.ApiKeyHeader, _apiKey);
|
||||||
|
}
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
await _socket.ConnectAsync(_endpointUri, connectionToken).ConfigureAwait(false);
|
await _socket.ConnectAsync(_endpointUri, connectionToken).ConfigureAwait(false);
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
McpLog.Error($"[WebSocket] Connection failed: {ex.Message}");
|
string errorMsg = "Connection failed. Check that the server URL is correct, the server is running, and your API key (if required) is valid.";
|
||||||
|
McpLog.Error($"[WebSocket] {errorMsg} (Detail: {ex.Message})");
|
||||||
|
_state = TransportState.Disconnected(TransportDisplayName, errorMsg);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -232,7 +245,9 @@ namespace MCPForUnity.Editor.Services.Transport.Transports
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
McpLog.Error($"[WebSocket] Registration failed: {ex.Message}");
|
string regMsg = $"Registration with server failed: {ex.Message}";
|
||||||
|
McpLog.Error($"[WebSocket] {regMsg}");
|
||||||
|
_state = TransportState.Disconnected(TransportDisplayName, regMsg);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -275,6 +275,7 @@ namespace MCPForUnity.Editor.Windows.Components.ClientConfig
|
||||||
string httpUrl = HttpEndpointUtility.GetMcpRpcUrl();
|
string httpUrl = HttpEndpointUtility.GetMcpRpcUrl();
|
||||||
var (uvxPath, gitUrl, packageName) = AssetPathUtility.GetUvxCommandParts();
|
var (uvxPath, gitUrl, packageName) = AssetPathUtility.GetUvxCommandParts();
|
||||||
bool shouldForceRefresh = AssetPathUtility.ShouldForceUvxRefresh();
|
bool shouldForceRefresh = AssetPathUtility.ShouldForceUvxRefresh();
|
||||||
|
string apiKey = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
|
||||||
|
|
||||||
// Compute pathPrepend on main thread
|
// Compute pathPrepend on main thread
|
||||||
string pathPrepend = null;
|
string pathPrepend = null;
|
||||||
|
|
@ -296,10 +297,12 @@ namespace MCPForUnity.Editor.Windows.Components.ClientConfig
|
||||||
{
|
{
|
||||||
if (client is ClaudeCliMcpConfigurator cliConfigurator)
|
if (client is ClaudeCliMcpConfigurator cliConfigurator)
|
||||||
{
|
{
|
||||||
|
var serverTransport = HttpEndpointUtility.GetCurrentServerTransport();
|
||||||
cliConfigurator.ConfigureWithCapturedValues(
|
cliConfigurator.ConfigureWithCapturedValues(
|
||||||
projectDir, claudePath, pathPrepend,
|
projectDir, claudePath, pathPrepend,
|
||||||
useHttpTransport, httpUrl,
|
useHttpTransport, httpUrl,
|
||||||
uvxPath, gitUrl, packageName, shouldForceRefresh);
|
uvxPath, gitUrl, packageName, shouldForceRefresh,
|
||||||
|
apiKey, serverTransport);
|
||||||
}
|
}
|
||||||
return (success: true, error: (string)null);
|
return (success: true, error: (string)null);
|
||||||
}
|
}
|
||||||
|
|
@ -525,12 +528,11 @@ namespace MCPForUnity.Editor.Windows.Components.ClientConfig
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for transport mismatch
|
// Check for transport mismatch (3-way: Stdio, Http, HttpRemote)
|
||||||
bool hasTransportMismatch = false;
|
bool hasTransportMismatch = false;
|
||||||
if (client.ConfiguredTransport != ConfiguredTransport.Unknown)
|
if (client.ConfiguredTransport != ConfiguredTransport.Unknown)
|
||||||
{
|
{
|
||||||
bool serverUsesHttp = EditorConfigurationCache.Instance.UseHttpTransport;
|
ConfiguredTransport serverTransport = HttpEndpointUtility.GetCurrentServerTransport();
|
||||||
ConfiguredTransport serverTransport = serverUsesHttp ? ConfiguredTransport.Http : ConfiguredTransport.Stdio;
|
|
||||||
hasTransportMismatch = client.ConfiguredTransport != serverTransport;
|
hasTransportMismatch = client.ConfiguredTransport != serverTransport;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,13 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
private Label connectionStatusLabel;
|
private Label connectionStatusLabel;
|
||||||
private Button connectionToggleButton;
|
private Button connectionToggleButton;
|
||||||
|
|
||||||
|
// API Key UI Elements (for remote-hosted mode)
|
||||||
|
private VisualElement apiKeyRow;
|
||||||
|
private TextField apiKeyField;
|
||||||
|
private Button getApiKeyButton;
|
||||||
|
private Button clearApiKeyButton;
|
||||||
|
private string cachedLoginUrl;
|
||||||
|
|
||||||
private bool connectionToggleInProgress;
|
private bool connectionToggleInProgress;
|
||||||
private bool httpServerToggleInProgress;
|
private bool httpServerToggleInProgress;
|
||||||
private Task verificationTask;
|
private Task verificationTask;
|
||||||
|
|
@ -93,6 +100,12 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
statusIndicator = Root.Q<VisualElement>("status-indicator");
|
statusIndicator = Root.Q<VisualElement>("status-indicator");
|
||||||
connectionStatusLabel = Root.Q<Label>("connection-status");
|
connectionStatusLabel = Root.Q<Label>("connection-status");
|
||||||
connectionToggleButton = Root.Q<Button>("connection-toggle");
|
connectionToggleButton = Root.Q<Button>("connection-toggle");
|
||||||
|
|
||||||
|
// API Key UI Elements
|
||||||
|
apiKeyRow = Root.Q<VisualElement>("api-key-row");
|
||||||
|
apiKeyField = Root.Q<TextField>("api-key-field");
|
||||||
|
getApiKeyButton = Root.Q<Button>("get-api-key-button");
|
||||||
|
clearApiKeyButton = Root.Q<Button>("clear-api-key-button");
|
||||||
}
|
}
|
||||||
|
|
||||||
private void InitializeUI()
|
private void InitializeUI()
|
||||||
|
|
@ -139,6 +152,15 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
|
|
||||||
httpUrlField.value = HttpEndpointUtility.GetBaseUrl();
|
httpUrlField.value = HttpEndpointUtility.GetBaseUrl();
|
||||||
|
|
||||||
|
// Initialize API key field
|
||||||
|
if (apiKeyField != null)
|
||||||
|
{
|
||||||
|
apiKeyField.value = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
|
||||||
|
apiKeyField.tooltip = "API key for remote-hosted MCP server authentication";
|
||||||
|
apiKeyField.isPasswordField = true;
|
||||||
|
apiKeyField.maskChar = '*';
|
||||||
|
}
|
||||||
|
|
||||||
int unityPort = EditorPrefs.GetInt(EditorPrefKeys.UnitySocketPort, 0);
|
int unityPort = EditorPrefs.GetInt(EditorPrefKeys.UnitySocketPort, 0);
|
||||||
if (unityPort == 0)
|
if (unityPort == 0)
|
||||||
{
|
{
|
||||||
|
|
@ -170,6 +192,8 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
EditorConfigurationCache.Instance.SetHttpTransportScope(scope);
|
EditorConfigurationCache.Instance.SetHttpTransportScope(scope);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Swap the displayed URL to match the newly selected scope
|
||||||
|
SyncUrlFieldToScope();
|
||||||
UpdateHttpFieldVisibility();
|
UpdateHttpFieldVisibility();
|
||||||
RefreshHttpUi();
|
RefreshHttpUi();
|
||||||
UpdateConnectionStatus();
|
UpdateConnectionStatus();
|
||||||
|
|
@ -247,6 +271,30 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
});
|
});
|
||||||
|
|
||||||
connectionToggleButton.clicked += OnConnectionToggleClicked;
|
connectionToggleButton.clicked += OnConnectionToggleClicked;
|
||||||
|
|
||||||
|
// API Key field callbacks
|
||||||
|
if (apiKeyField != null)
|
||||||
|
{
|
||||||
|
apiKeyField.RegisterCallback<FocusOutEvent>(_ => PersistApiKeyFromField());
|
||||||
|
apiKeyField.RegisterCallback<KeyDownEvent>(evt =>
|
||||||
|
{
|
||||||
|
if (evt.keyCode == KeyCode.Return || evt.keyCode == KeyCode.KeypadEnter)
|
||||||
|
{
|
||||||
|
PersistApiKeyFromField();
|
||||||
|
evt.StopPropagation();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getApiKeyButton != null)
|
||||||
|
{
|
||||||
|
getApiKeyButton.clicked += OnGetApiKeyClicked;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (clearApiKeyButton != null)
|
||||||
|
{
|
||||||
|
clearApiKeyButton.clicked += OnClearApiKeyClicked;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void PersistHttpUrlFromField()
|
private void PersistHttpUrlFromField()
|
||||||
|
|
@ -259,6 +307,8 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
HttpEndpointUtility.SaveBaseUrl(httpUrlField.text);
|
HttpEndpointUtility.SaveBaseUrl(httpUrlField.text);
|
||||||
// Update displayed value to normalized form without re-triggering callbacks/caret jumps.
|
// Update displayed value to normalized form without re-triggering callbacks/caret jumps.
|
||||||
httpUrlField.SetValueWithoutNotify(HttpEndpointUtility.GetBaseUrl());
|
httpUrlField.SetValueWithoutNotify(HttpEndpointUtility.GetBaseUrl());
|
||||||
|
// Invalidate cached login URL so it is re-fetched for the new base URL.
|
||||||
|
cachedLoginUrl = null;
|
||||||
OnManualConfigUpdateRequested?.Invoke();
|
OnManualConfigUpdateRequested?.Invoke();
|
||||||
RefreshHttpUi();
|
RefreshHttpUi();
|
||||||
}
|
}
|
||||||
|
|
@ -338,7 +388,15 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
statusIndicator.RemoveFromClassList("connected");
|
statusIndicator.RemoveFromClassList("connected");
|
||||||
statusIndicator.AddToClassList("disconnected");
|
statusIndicator.AddToClassList("disconnected");
|
||||||
connectionToggleButton.text = "Start Session";
|
connectionToggleButton.text = "Start Session";
|
||||||
connectionToggleButton.SetEnabled(true);
|
|
||||||
|
// Disable Start Session for HTTP Remote when no API key is set
|
||||||
|
bool httpRemoteNeedsKey = transportDropdown != null
|
||||||
|
&& (TransportProtocol)transportDropdown.value == TransportProtocol.HTTPRemote
|
||||||
|
&& string.IsNullOrEmpty(EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty));
|
||||||
|
connectionToggleButton.SetEnabled(!httpRemoteNeedsKey);
|
||||||
|
connectionToggleButton.tooltip = httpRemoteNeedsKey
|
||||||
|
? "An API key is required for HTTP Remote. Enter one above."
|
||||||
|
: string.Empty;
|
||||||
}
|
}
|
||||||
|
|
||||||
unityPortField.SetEnabled(!isStdioResuming);
|
unityPortField.SetEnabled(!isStdioResuming);
|
||||||
|
|
@ -439,10 +497,19 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
{
|
{
|
||||||
bool useHttp = (TransportProtocol)transportDropdown.value != TransportProtocol.Stdio;
|
bool useHttp = (TransportProtocol)transportDropdown.value != TransportProtocol.Stdio;
|
||||||
bool httpLocalSelected = IsHttpLocalSelected();
|
bool httpLocalSelected = IsHttpLocalSelected();
|
||||||
|
bool httpRemoteSelected = transportDropdown != null && (TransportProtocol)transportDropdown.value == TransportProtocol.HTTPRemote;
|
||||||
|
|
||||||
httpUrlRow.style.display = useHttp ? DisplayStyle.Flex : DisplayStyle.None;
|
httpUrlRow.style.display = useHttp ? DisplayStyle.Flex : DisplayStyle.None;
|
||||||
httpServerControlRow.style.display = useHttp && httpLocalSelected ? DisplayStyle.Flex : DisplayStyle.None;
|
httpServerControlRow.style.display = useHttp && httpLocalSelected ? DisplayStyle.Flex : DisplayStyle.None;
|
||||||
unitySocketPortRow.style.display = useHttp ? DisplayStyle.None : DisplayStyle.Flex;
|
unitySocketPortRow.style.display = useHttp ? DisplayStyle.None : DisplayStyle.Flex;
|
||||||
|
|
||||||
|
// Manual Server Launch foldout only relevant for HTTP Local
|
||||||
|
if (manualCommandFoldout != null)
|
||||||
|
manualCommandFoldout.style.display = httpLocalSelected ? DisplayStyle.Flex : DisplayStyle.None;
|
||||||
|
|
||||||
|
// API key fields only visible in HTTP Remote mode
|
||||||
|
if (apiKeyRow != null)
|
||||||
|
apiKeyRow.style.display = httpRemoteSelected ? DisplayStyle.Flex : DisplayStyle.None;
|
||||||
}
|
}
|
||||||
|
|
||||||
private bool IsHttpLocalSelected()
|
private bool IsHttpLocalSelected()
|
||||||
|
|
@ -450,6 +517,13 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
return transportDropdown != null && (TransportProtocol)transportDropdown.value == TransportProtocol.HTTPLocal;
|
return transportDropdown != null && (TransportProtocol)transportDropdown.value == TransportProtocol.HTTPLocal;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void SyncUrlFieldToScope()
|
||||||
|
{
|
||||||
|
if (httpUrlField == null) return;
|
||||||
|
httpUrlField.SetValueWithoutNotify(HttpEndpointUtility.GetBaseUrl());
|
||||||
|
cachedLoginUrl = null;
|
||||||
|
}
|
||||||
|
|
||||||
private void UpdateStartHttpButtonState()
|
private void UpdateStartHttpButtonState()
|
||||||
{
|
{
|
||||||
if (startHttpServerButton == null)
|
if (startHttpServerButton == null)
|
||||||
|
|
@ -674,7 +748,13 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
McpLog.Warn("Failed to start MCP bridge");
|
var mode = EditorConfigurationCache.Instance.UseHttpTransport
|
||||||
|
? TransportMode.Http : TransportMode.Stdio;
|
||||||
|
var state = MCPServiceLocator.TransportManager.GetState(mode);
|
||||||
|
string errorMsg = state?.Error
|
||||||
|
?? "Failed to start the MCP session. Check the server URL and that the server is running.";
|
||||||
|
EditorUtility.DisplayDialog("Connection Failed", errorMsg, "OK");
|
||||||
|
McpLog.Warn($"Failed to start MCP bridge: {errorMsg}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -720,6 +800,110 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void PersistApiKeyFromField()
|
||||||
|
{
|
||||||
|
if (apiKeyField == null)
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
string apiKey = apiKeyField.text?.Trim() ?? string.Empty;
|
||||||
|
string existingKey = EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty);
|
||||||
|
|
||||||
|
if (apiKey != existingKey)
|
||||||
|
{
|
||||||
|
EditorPrefs.SetString(EditorPrefKeys.ApiKey, apiKey);
|
||||||
|
OnManualConfigUpdateRequested?.Invoke();
|
||||||
|
UpdateConnectionStatus();
|
||||||
|
McpLog.Info(string.IsNullOrEmpty(apiKey) ? "API key cleared" : "API key updated");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async void OnGetApiKeyClicked()
|
||||||
|
{
|
||||||
|
if (getApiKeyButton != null)
|
||||||
|
{
|
||||||
|
getApiKeyButton.SetEnabled(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
string loginUrl = await GetLoginUrlAsync();
|
||||||
|
if (string.IsNullOrEmpty(loginUrl))
|
||||||
|
{
|
||||||
|
EditorUtility.DisplayDialog("API Key",
|
||||||
|
"API key management is not available for this server. Contact your server administrator.",
|
||||||
|
"OK");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Application.OpenURL(loginUrl);
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
McpLog.Error($"Failed to get login URL: {ex.Message}");
|
||||||
|
EditorUtility.DisplayDialog("Error",
|
||||||
|
$"Failed to get API key login URL:\n\n{ex.Message}",
|
||||||
|
"OK");
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
if (getApiKeyButton != null)
|
||||||
|
{
|
||||||
|
getApiKeyButton.SetEnabled(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async Task<string> GetLoginUrlAsync()
|
||||||
|
{
|
||||||
|
if (!string.IsNullOrEmpty(cachedLoginUrl))
|
||||||
|
{
|
||||||
|
return cachedLoginUrl;
|
||||||
|
}
|
||||||
|
|
||||||
|
string baseUrl = HttpEndpointUtility.GetBaseUrl();
|
||||||
|
string loginUrlEndpoint = $"{baseUrl.TrimEnd('/')}/api/auth/login-url";
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
using (var client = new System.Net.Http.HttpClient())
|
||||||
|
{
|
||||||
|
client.Timeout = TimeSpan.FromSeconds(10);
|
||||||
|
var response = await client.GetAsync(loginUrlEndpoint);
|
||||||
|
|
||||||
|
if (response.IsSuccessStatusCode)
|
||||||
|
{
|
||||||
|
string json = await response.Content.ReadAsStringAsync();
|
||||||
|
var result = Newtonsoft.Json.Linq.JObject.Parse(json);
|
||||||
|
|
||||||
|
if (result.Value<bool>("success"))
|
||||||
|
{
|
||||||
|
cachedLoginUrl = result.Value<string>("login_url");
|
||||||
|
return cachedLoginUrl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
McpLog.Debug($"Failed to fetch login URL from {loginUrlEndpoint}: {ex.Message}");
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void OnClearApiKeyClicked()
|
||||||
|
{
|
||||||
|
EditorPrefs.SetString(EditorPrefKeys.ApiKey, string.Empty);
|
||||||
|
if (apiKeyField != null)
|
||||||
|
{
|
||||||
|
apiKeyField.SetValueWithoutNotify(string.Empty);
|
||||||
|
}
|
||||||
|
OnManualConfigUpdateRequested?.Invoke();
|
||||||
|
UpdateConnectionStatus();
|
||||||
|
McpLog.Info("API key cleared");
|
||||||
|
}
|
||||||
|
|
||||||
public async Task VerifyBridgeConnectionAsync()
|
public async Task VerifyBridgeConnectionAsync()
|
||||||
{
|
{
|
||||||
// Prevent concurrent verification calls
|
// Prevent concurrent verification calls
|
||||||
|
|
@ -810,17 +994,16 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine the server's current transport setting
|
// Determine the server's current transport setting (3-way: Stdio, Http, HttpRemote)
|
||||||
bool serverUsesHttp = EditorConfigurationCache.Instance.UseHttpTransport;
|
ConfiguredTransport serverTransport = HttpEndpointUtility.GetCurrentServerTransport();
|
||||||
ConfiguredTransport serverTransport = serverUsesHttp ? ConfiguredTransport.Http : ConfiguredTransport.Stdio;
|
|
||||||
|
|
||||||
// Check for mismatch
|
// Check for mismatch
|
||||||
bool hasMismatch = clientTransport != serverTransport;
|
bool hasMismatch = clientTransport != serverTransport;
|
||||||
|
|
||||||
if (hasMismatch)
|
if (hasMismatch)
|
||||||
{
|
{
|
||||||
string clientTransportName = clientTransport == ConfiguredTransport.Http ? "HTTP" : "stdio";
|
string clientTransportName = TransportDisplayName(clientTransport);
|
||||||
string serverTransportName = serverTransport == ConfiguredTransport.Http ? "HTTP" : "stdio";
|
string serverTransportName = TransportDisplayName(serverTransport);
|
||||||
|
|
||||||
transportMismatchText.text = $"⚠ {clientName} is configured for \"{clientTransportName}\" but server is set to \"{serverTransportName}\". " +
|
transportMismatchText.text = $"⚠ {clientName} is configured for \"{clientTransportName}\" but server is set to \"{serverTransportName}\". " +
|
||||||
"Click \"Configure\" in Client Configuration to update.";
|
"Click \"Configure\" in Client Configuration to update.";
|
||||||
|
|
@ -839,5 +1022,16 @@ namespace MCPForUnity.Editor.Windows.Components.Connection
|
||||||
{
|
{
|
||||||
transportMismatchWarning?.RemoveFromClassList("visible");
|
transportMismatchWarning?.RemoveFromClassList("visible");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static string TransportDisplayName(ConfiguredTransport transport)
|
||||||
|
{
|
||||||
|
return transport switch
|
||||||
|
{
|
||||||
|
ConfiguredTransport.Stdio => "stdio",
|
||||||
|
ConfiguredTransport.Http => "HTTP Local",
|
||||||
|
ConfiguredTransport.HttpRemote => "HTTP Remote",
|
||||||
|
_ => "unknown"
|
||||||
|
};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,16 @@
|
||||||
<ui:Label text="HTTP URL:" class="setting-label" />
|
<ui:Label text="HTTP URL:" class="setting-label" />
|
||||||
<ui:TextField name="http-url" class="url-field" />
|
<ui:TextField name="http-url" class="url-field" />
|
||||||
</ui:VisualElement>
|
</ui:VisualElement>
|
||||||
|
<ui:VisualElement name="api-key-row" style="margin-bottom: 4px;">
|
||||||
|
<ui:VisualElement class="setting-row">
|
||||||
|
<ui:Label text="API Key:" class="setting-label" />
|
||||||
|
<ui:TextField name="api-key-field" password="true" class="url-field" />
|
||||||
|
</ui:VisualElement>
|
||||||
|
<ui:VisualElement style="flex-direction: row; justify-content: flex-end;">
|
||||||
|
<ui:Button name="get-api-key-button" text="Get API Key" class="action-button" />
|
||||||
|
<ui:Button name="clear-api-key-button" text="Clear" class="action-button" />
|
||||||
|
</ui:VisualElement>
|
||||||
|
</ui:VisualElement>
|
||||||
<ui:VisualElement class="setting-row" name="http-server-control-row">
|
<ui:VisualElement class="setting-row" name="http-server-control-row">
|
||||||
<ui:Label text="Local Server:" class="setting-label" />
|
<ui:Label text="Local Server:" class="setting-label" />
|
||||||
<ui:Button name="start-http-server-button" text="Start Server" class="action-button start-server-button" />
|
<ui:Button name="start-http-server-button" text="Start Server" class="action-button start-server-button" />
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,7 @@ namespace MCPForUnity.Editor.Windows
|
||||||
{ EditorPrefKeys.ClaudeCliPathOverride, EditorPrefType.String },
|
{ EditorPrefKeys.ClaudeCliPathOverride, EditorPrefType.String },
|
||||||
{ EditorPrefKeys.UvxPathOverride, EditorPrefType.String },
|
{ EditorPrefKeys.UvxPathOverride, EditorPrefType.String },
|
||||||
{ EditorPrefKeys.HttpBaseUrl, EditorPrefType.String },
|
{ EditorPrefKeys.HttpBaseUrl, EditorPrefType.String },
|
||||||
|
{ EditorPrefKeys.HttpRemoteBaseUrl, EditorPrefType.String },
|
||||||
{ EditorPrefKeys.HttpTransportScope, EditorPrefType.String },
|
{ EditorPrefKeys.HttpTransportScope, EditorPrefType.String },
|
||||||
{ EditorPrefKeys.SessionId, EditorPrefType.String },
|
{ EditorPrefKeys.SessionId, EditorPrefType.String },
|
||||||
{ EditorPrefKeys.WebSocketUrlOverride, EditorPrefType.String },
|
{ EditorPrefKeys.WebSocketUrlOverride, EditorPrefType.String },
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,54 @@ docker run -p 8080:8080 -e LOG_LEVEL=DEBUG msanatan/mcp-for-unity-server:latest
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Remote-Hosted Mode
|
||||||
|
|
||||||
|
To deploy as a shared remote service with API key authentication and per-user session isolation, pass `--http-remote-hosted` along with an API key validation URL:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run -p 8080:8080 \
|
||||||
|
-e UNITY_MCP_HTTP_REMOTE_HOSTED=true \
|
||||||
|
-e UNITY_MCP_API_KEY_VALIDATION_URL=https://auth.example.com/api/validate-key \
|
||||||
|
-e UNITY_MCP_API_KEY_LOGIN_URL=https://app.example.com/api-keys \
|
||||||
|
msanatan/mcp-for-unity-server:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
In this mode:
|
||||||
|
|
||||||
|
- All MCP tool/resource calls and Unity plugin WebSocket connections require a valid `X-API-Key` header.
|
||||||
|
- Each user only sees Unity instances that connected with their API key.
|
||||||
|
- Users must explicitly call `set_active_instance` to select a Unity instance.
|
||||||
|
|
||||||
|
**Remote-hosted environment variables:**
|
||||||
|
|
||||||
|
| Variable | Description |
|
||||||
|
|----------|-------------|
|
||||||
|
| `UNITY_MCP_HTTP_REMOTE_HOSTED` | Enable remote-hosted mode (`true`, `1`, or `yes`) |
|
||||||
|
| `UNITY_MCP_API_KEY_VALIDATION_URL` | External endpoint to validate API keys (required) |
|
||||||
|
| `UNITY_MCP_API_KEY_LOGIN_URL` | URL where users can obtain/manage API keys |
|
||||||
|
| `UNITY_MCP_API_KEY_CACHE_TTL` | Cache TTL for validated keys in seconds (default: `300`) |
|
||||||
|
| `UNITY_MCP_API_KEY_SERVICE_TOKEN_HEADER` | Header name for server-to-auth-service authentication |
|
||||||
|
| `UNITY_MCP_API_KEY_SERVICE_TOKEN` | Token value sent to the auth service |
|
||||||
|
|
||||||
|
**MCP client config with API key:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"UnityMCP": {
|
||||||
|
"url": "http://your-server:8080/mcp",
|
||||||
|
"headers": {
|
||||||
|
"X-API-Key": "<your-api-key>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
For full details, see the [Remote Server Auth Guide](https://github.com/CoplayDev/unity-mcp/blob/main/docs/guides/REMOTE_SERVER_AUTH.md).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Example Prompts
|
## Example Prompts
|
||||||
|
|
||||||
Once connected, try these commands in your AI assistant:
|
Once connected, try these commands in your AI assistant:
|
||||||
|
|
|
||||||
120
Server/README.md
120
Server/README.md
|
|
@ -113,12 +113,124 @@ uv run src/main.py --transport stdio
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
The server connects to Unity Editor automatically when both are running. No additional configuration needed.
|
The server connects to Unity Editor automatically when both are running. Most users do not need to change any settings.
|
||||||
|
|
||||||
**Environment Variables:**
|
### CLI options
|
||||||
|
|
||||||
- `DISABLE_TELEMETRY=true` - Opt out of anonymous usage analytics
|
These options apply to the `mcp-for-unity` command (whether run via `uvx`, Docker, or `python src/main.py`).
|
||||||
- `LOG_LEVEL=DEBUG` - Enable detailed logging (default: INFO)
|
|
||||||
|
- `--transport {stdio,http}` - Transport protocol (default: `stdio`)
|
||||||
|
- `--http-url URL` - Base URL used to derive host/port defaults (default: `http://localhost:8080`)
|
||||||
|
- `--http-host HOST` - Override HTTP bind host (overrides URL host)
|
||||||
|
- `--http-port PORT` - Override HTTP bind port (overrides URL port)
|
||||||
|
- `--http-remote-hosted` - Treat HTTP transport as remotely hosted
|
||||||
|
- Requires API key authentication (see below)
|
||||||
|
- Disables local/CLI-only HTTP routes (`/api/command`, `/api/instances`, `/api/custom-tools`)
|
||||||
|
- Forces explicit Unity instance selection for MCP tool/resource calls
|
||||||
|
- Isolates Unity sessions per user
|
||||||
|
- `--api-key-validation-url URL` - External endpoint to validate API keys (required when `--http-remote-hosted` is set)
|
||||||
|
- `--api-key-login-url URL` - URL where users can obtain/manage API keys (served by `/api/auth/login-url`)
|
||||||
|
- `--api-key-cache-ttl SECONDS` - Cache duration for validated keys (default: `300`)
|
||||||
|
- `--api-key-service-token-header HEADER` - Header name for server-to-auth-service authentication (e.g. `X-Service-Token`)
|
||||||
|
- `--api-key-service-token TOKEN` - Token value sent to the auth service for server authentication
|
||||||
|
- `--default-instance INSTANCE` - Default Unity instance to target (project name, hash, or `Name@hash`)
|
||||||
|
- `--project-scoped-tools` - Keep custom tools scoped to the active Unity project and enable the custom tools resource
|
||||||
|
- `--unity-instance-token TOKEN` - Optional per-launch token set by Unity for deterministic lifecycle management
|
||||||
|
- `--pidfile PATH` - Optional path where the server writes its PID on startup (used by Unity-managed terminal launches)
|
||||||
|
|
||||||
|
### Environment variables
|
||||||
|
|
||||||
|
- `UNITY_MCP_TRANSPORT` - Transport protocol: `stdio` or `http`
|
||||||
|
- `UNITY_MCP_HTTP_URL` - HTTP server URL (default: `http://localhost:8080`)
|
||||||
|
- `UNITY_MCP_HTTP_HOST` - HTTP bind host (overrides URL host)
|
||||||
|
- `UNITY_MCP_HTTP_PORT` - HTTP bind port (overrides URL port)
|
||||||
|
- `UNITY_MCP_HTTP_REMOTE_HOSTED` - Enable remote-hosted mode (`true`, `1`, or `yes`)
|
||||||
|
- `UNITY_MCP_DEFAULT_INSTANCE` - Default Unity instance to target (project name, hash, or `Name@hash`)
|
||||||
|
- `UNITY_MCP_SKIP_STARTUP_CONNECT=1` - Skip initial Unity connection attempt on startup
|
||||||
|
|
||||||
|
API key authentication (remote-hosted mode):
|
||||||
|
|
||||||
|
- `UNITY_MCP_API_KEY_VALIDATION_URL` - External endpoint to validate API keys
|
||||||
|
- `UNITY_MCP_API_KEY_LOGIN_URL` - URL where users can obtain/manage API keys
|
||||||
|
- `UNITY_MCP_API_KEY_CACHE_TTL` - Cache TTL for validated keys in seconds (default: `300`)
|
||||||
|
- `UNITY_MCP_API_KEY_SERVICE_TOKEN_HEADER` - Header name for server-to-auth-service authentication
|
||||||
|
- `UNITY_MCP_API_KEY_SERVICE_TOKEN` - Token value sent to the auth service for server authentication
|
||||||
|
|
||||||
|
Telemetry:
|
||||||
|
|
||||||
|
- `DISABLE_TELEMETRY=1` - Disable anonymous telemetry (opt-out)
|
||||||
|
- `UNITY_MCP_DISABLE_TELEMETRY=1` - Same as `DISABLE_TELEMETRY`
|
||||||
|
- `MCP_DISABLE_TELEMETRY=1` - Same as `DISABLE_TELEMETRY`
|
||||||
|
- `UNITY_MCP_TELEMETRY_ENDPOINT` - Override telemetry endpoint URL
|
||||||
|
- `UNITY_MCP_TELEMETRY_TIMEOUT` - Override telemetry request timeout (seconds)
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
**Stdio (default):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvx --from mcpforunityserver mcp-for-unity --transport stdio
|
||||||
|
```
|
||||||
|
|
||||||
|
**HTTP (local):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvx --from mcpforunityserver mcp-for-unity --transport http --http-host 127.0.0.1 --http-port 8080
|
||||||
|
```
|
||||||
|
|
||||||
|
**HTTP (remote-hosted with API key auth):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvx --from mcpforunityserver mcp-for-unity \
|
||||||
|
--transport http \
|
||||||
|
--http-host 0.0.0.0 \
|
||||||
|
--http-port 8080 \
|
||||||
|
--http-remote-hosted \
|
||||||
|
--api-key-validation-url https://auth.example.com/api/validate-key \
|
||||||
|
--api-key-login-url https://app.example.com/api-keys
|
||||||
|
```
|
||||||
|
|
||||||
|
**Disable telemetry:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
DISABLE_TELEMETRY=1 uvx --from mcpforunityserver mcp-for-unity --transport stdio
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Remote-Hosted Mode
|
||||||
|
|
||||||
|
When deploying the server as a shared remote service (e.g. for a team or Asset Store users), enable `--http-remote-hosted` to activate API key authentication and per-user session isolation.
|
||||||
|
|
||||||
|
**Requirements:**
|
||||||
|
|
||||||
|
- An external HTTP endpoint that validates API keys. The server POSTs `{"api_key": "..."}` and expects `{"valid": true, "user_id": "..."}` or `{"valid": false}` in response.
|
||||||
|
- `--api-key-validation-url` must be provided (or `UNITY_MCP_API_KEY_VALIDATION_URL`). The server exits with code 1 if this is missing.
|
||||||
|
|
||||||
|
**What changes in remote-hosted mode:**
|
||||||
|
|
||||||
|
- All MCP tool/resource calls and Unity plugin WebSocket connections require a valid `X-API-Key` header.
|
||||||
|
- Each user only sees Unity instances that connected with their API key (session isolation).
|
||||||
|
- Auto-selection of a sole Unity instance is disabled; users must explicitly call `set_active_instance`.
|
||||||
|
- CLI REST routes (`/api/command`, `/api/instances`, `/api/custom-tools`) are disabled.
|
||||||
|
- `/health` and `/api/auth/login-url` remain accessible without authentication.
|
||||||
|
|
||||||
|
**MCP client config with API key:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"UnityMCP": {
|
||||||
|
"url": "http://remote-server:8080/mcp",
|
||||||
|
"headers": {
|
||||||
|
"X-API-Key": "<your-api-key>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
For full details, see [Remote Server Auth Guide](../docs/guides/REMOTE_SERVER_AUTH.md) and [Architecture Reference](../docs/reference/REMOTE_SERVER_AUTH_ARCHITECTURE.md).
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -182,38 +182,34 @@ async def list_unity_instances(config: Optional[CLIConfig] = None) -> Dict[str,
|
||||||
"""
|
"""
|
||||||
cfg = config or get_config()
|
cfg = config or get_config()
|
||||||
|
|
||||||
# Try the new /api/instances endpoint first, fall back to /plugin/sessions
|
url = f"http://{cfg.host}:{cfg.port}/api/instances"
|
||||||
urls_to_try = [
|
|
||||||
f"http://{cfg.host}:{cfg.port}/api/instances",
|
|
||||||
f"http://{cfg.host}:{cfg.port}/plugin/sessions",
|
|
||||||
]
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
for url in urls_to_try:
|
|
||||||
try:
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.get(url, timeout=10)
|
response = await client.get(url, timeout=10)
|
||||||
if response.status_code == 200:
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
# Normalize response format
|
|
||||||
if "instances" in data:
|
if "instances" in data:
|
||||||
return data
|
return data
|
||||||
elif "sessions" in data:
|
except httpx.ConnectError as e:
|
||||||
# Convert sessions format to instances format
|
|
||||||
instances = []
|
|
||||||
for session_id, details in data["sessions"].items():
|
|
||||||
instances.append({
|
|
||||||
"session_id": session_id,
|
|
||||||
"project": details.get("project", "Unknown"),
|
|
||||||
"hash": details.get("hash", ""),
|
|
||||||
"unity_version": details.get("unity_version", "Unknown"),
|
|
||||||
"connected_at": details.get("connected_at", ""),
|
|
||||||
})
|
|
||||||
return {"success": True, "instances": instances}
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
raise UnityConnectionError(
|
raise UnityConnectionError(
|
||||||
"Failed to list Unity instances: No working endpoint found")
|
f"Cannot connect to Unity MCP server at {cfg.host}:{cfg.port}. "
|
||||||
|
f"Make sure the server is running and Unity is connected.\n"
|
||||||
|
f"Error: {e}"
|
||||||
|
)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise UnityConnectionError(
|
||||||
|
"Connection to Unity timed out while listing instances. "
|
||||||
|
"Unity may be busy or unresponsive."
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise UnityConnectionError(
|
||||||
|
f"HTTP error from server: {e.response.status_code} - {e.response.text}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise UnityConnectionError(f"Unexpected error: {e}")
|
||||||
|
|
||||||
|
raise UnityConnectionError("Failed to list Unity instances")
|
||||||
|
|
||||||
|
|
||||||
def run_list_instances(config: Optional[CLIConfig] = None) -> Dict[str, Any]:
|
def run_list_instances(config: Optional[CLIConfig] = None) -> Dict[str, Any]:
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,21 @@ class ServerConfig:
|
||||||
unity_port: int = 6400
|
unity_port: int = 6400
|
||||||
mcp_port: int = 6500
|
mcp_port: int = 6500
|
||||||
|
|
||||||
|
# Transport settings
|
||||||
|
transport_mode: str = "stdio"
|
||||||
|
|
||||||
|
# HTTP transport behaviour
|
||||||
|
http_remote_hosted: bool = False
|
||||||
|
|
||||||
|
# API key authentication (required when http_remote_hosted=True)
|
||||||
|
api_key_validation_url: str | None = None # POST endpoint to validate keys
|
||||||
|
api_key_login_url: str | None = None # URL for users to get/manage keys
|
||||||
|
# Cache TTL in seconds (5 min default)
|
||||||
|
api_key_cache_ttl: float = 300.0
|
||||||
|
# Optional service token for authenticating to the validation endpoint
|
||||||
|
api_key_service_token_header: str | None = None # e.g. "X-Service-Token"
|
||||||
|
api_key_service_token: str | None = None # The token value
|
||||||
|
|
||||||
# Connection settings
|
# Connection settings
|
||||||
connection_timeout: float = 30.0
|
connection_timeout: float = 30.0
|
||||||
buffer_size: int = 16 * 1024 * 1024 # 16MB buffer
|
buffer_size: int = 16 * 1024 * 1024 # 16MB buffer
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
"""Server-wide protocol constants."""
|
||||||
|
|
||||||
|
# HTTP header name for API key authentication
|
||||||
|
API_KEY_HEADER = "X-API-Key"
|
||||||
|
|
@ -3,6 +3,7 @@ from transport.unity_instance_middleware import (
|
||||||
UnityInstanceMiddleware,
|
UnityInstanceMiddleware,
|
||||||
get_unity_instance_middleware
|
get_unity_instance_middleware
|
||||||
)
|
)
|
||||||
|
from services.api_key_service import ApiKeyService
|
||||||
from transport.legacy.unity_connection import get_unity_connection_pool, UnityConnectionPool
|
from transport.legacy.unity_connection import get_unity_connection_pool, UnityConnectionPool
|
||||||
from services.tools import register_all_tools
|
from services.tools import register_all_tools
|
||||||
from core.telemetry import record_milestone, record_telemetry, MilestoneType, RecordType, get_package_version
|
from core.telemetry import record_milestone, record_telemetry, MilestoneType, RecordType, get_package_version
|
||||||
|
|
@ -312,6 +313,15 @@ Payload sizing & paging (important):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_instance_token(instance_token: str | None) -> tuple[str | None, str | None]:
|
||||||
|
if not instance_token:
|
||||||
|
return None, None
|
||||||
|
if "@" in instance_token:
|
||||||
|
name_part, _, hash_part = instance_token.partition("@")
|
||||||
|
return (name_part or None), (hash_part or None)
|
||||||
|
return None, instance_token
|
||||||
|
|
||||||
|
|
||||||
def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
|
def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
|
||||||
mcp = FastMCP(
|
mcp = FastMCP(
|
||||||
name="mcp-for-unity-server",
|
name="mcp-for-unity-server",
|
||||||
|
|
@ -332,14 +342,24 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
|
||||||
"message": "MCP for Unity server is running"
|
"message": "MCP for Unity server is running"
|
||||||
})
|
})
|
||||||
|
|
||||||
def _normalize_instance_token(instance_token: str | None) -> tuple[str | None, str | None]:
|
@mcp.custom_route("/api/auth/login-url", methods=["GET"])
|
||||||
if not instance_token:
|
async def auth_login_url(_: Request) -> JSONResponse:
|
||||||
return None, None
|
"""Return the login URL for users to obtain/manage API keys."""
|
||||||
if "@" in instance_token:
|
if not config.api_key_login_url:
|
||||||
name_part, _, hash_part = instance_token.partition("@")
|
return JSONResponse(
|
||||||
return (name_part or None), (hash_part or None)
|
{
|
||||||
return None, instance_token
|
"success": False,
|
||||||
|
"error": "API key management not configured. Contact your server administrator.",
|
||||||
|
},
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True,
|
||||||
|
"login_url": config.api_key_login_url,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Only expose CLI routes if running locally (not in remote hosted mode)
|
||||||
|
if not config.http_remote_hosted:
|
||||||
@mcp.custom_route("/api/command", methods=["POST"])
|
@mcp.custom_route("/api/command", methods=["POST"])
|
||||||
async def cli_command_route(request: Request) -> JSONResponse:
|
async def cli_command_route(request: Request) -> JSONResponse:
|
||||||
"""REST endpoint for CLI commands to Unity."""
|
"""REST endpoint for CLI commands to Unity."""
|
||||||
|
|
@ -364,7 +384,8 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
|
||||||
# Find target session
|
# Find target session
|
||||||
session_id = None
|
session_id = None
|
||||||
session_details = None
|
session_details = None
|
||||||
instance_name, instance_hash = _normalize_instance_token(unity_instance)
|
instance_name, instance_hash = _normalize_instance_token(
|
||||||
|
unity_instance)
|
||||||
if unity_instance:
|
if unity_instance:
|
||||||
# Try to match by hash or project name
|
# Try to match by hash or project name
|
||||||
for sid, details in sessions.sessions.items():
|
for sid, details in sessions.sessions.items():
|
||||||
|
|
@ -391,19 +412,23 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
|
||||||
tool_name = None
|
tool_name = None
|
||||||
tool_params = {}
|
tool_params = {}
|
||||||
if isinstance(params, dict):
|
if isinstance(params, dict):
|
||||||
tool_name = params.get("tool_name") or params.get("name")
|
tool_name = params.get(
|
||||||
tool_params = params.get("parameters") or params.get("params") or {}
|
"tool_name") or params.get("name")
|
||||||
|
tool_params = params.get(
|
||||||
|
"parameters") or params.get("params") or {}
|
||||||
|
|
||||||
if not tool_name:
|
if not tool_name:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{"success": False, "error": "Missing 'tool_name' for execute_custom_tool"},
|
{"success": False,
|
||||||
|
"error": "Missing 'tool_name' for execute_custom_tool"},
|
||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
if tool_params is None:
|
if tool_params is None:
|
||||||
tool_params = {}
|
tool_params = {}
|
||||||
if not isinstance(tool_params, dict):
|
if not isinstance(tool_params, dict):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{"success": False, "error": "Tool parameters must be an object/dict"},
|
{"success": False,
|
||||||
|
"error": "Tool parameters must be an object/dict"},
|
||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -416,7 +441,8 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
|
||||||
unity_instance_hint)
|
unity_instance_hint)
|
||||||
if not project_id:
|
if not project_id:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{"success": False, "error": "Could not resolve project id for custom tool"},
|
{"success": False,
|
||||||
|
"error": "Could not resolve project id for custom tool"},
|
||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -431,7 +457,25 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
|
||||||
return JSONResponse(result)
|
return JSONResponse(result)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"CLI command error: {e}")
|
logger.exception("CLI command error: %s", e)
|
||||||
|
return JSONResponse({"success": False, "error": str(e)}, status_code=500)
|
||||||
|
|
||||||
|
@mcp.custom_route("/api/instances", methods=["GET"])
|
||||||
|
async def cli_instances_route(_: Request) -> JSONResponse:
|
||||||
|
"""REST endpoint to list connected Unity instances."""
|
||||||
|
try:
|
||||||
|
sessions = await PluginHub.get_sessions()
|
||||||
|
instances = []
|
||||||
|
for session_id, details in sessions.sessions.items():
|
||||||
|
instances.append({
|
||||||
|
"session_id": session_id,
|
||||||
|
"project": details.project,
|
||||||
|
"hash": details.hash,
|
||||||
|
"unity_version": details.unity_version,
|
||||||
|
"connected_at": details.connected_at,
|
||||||
|
})
|
||||||
|
return JSONResponse({"success": True, "instances": instances})
|
||||||
|
except Exception as e:
|
||||||
return JSONResponse({"success": False, "error": str(e)}, status_code=500)
|
return JSONResponse({"success": False, "error": str(e)}, status_code=500)
|
||||||
|
|
||||||
@mcp.custom_route("/api/custom-tools", methods=["GET"])
|
@mcp.custom_route("/api/custom-tools", methods=["GET"])
|
||||||
|
|
@ -439,7 +483,8 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
|
||||||
"""REST endpoint to list custom tools for the active Unity project."""
|
"""REST endpoint to list custom tools for the active Unity project."""
|
||||||
try:
|
try:
|
||||||
unity_instance = request.query_params.get("instance")
|
unity_instance = request.query_params.get("instance")
|
||||||
instance_name, instance_hash = _normalize_instance_token(unity_instance)
|
instance_name, instance_hash = _normalize_instance_token(
|
||||||
|
unity_instance)
|
||||||
|
|
||||||
sessions = await PluginHub.get_sessions()
|
sessions = await PluginHub.get_sessions()
|
||||||
if not sessions.sessions:
|
if not sessions.sessions:
|
||||||
|
|
@ -475,7 +520,8 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
|
||||||
unity_instance_hint)
|
unity_instance_hint)
|
||||||
if not project_id:
|
if not project_id:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{"success": False, "error": "Could not resolve project id for custom tools"},
|
{"success": False,
|
||||||
|
"error": "Could not resolve project id for custom tools"},
|
||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -492,38 +538,29 @@ def create_mcp_server(project_scoped_tools: bool) -> FastMCP:
|
||||||
"tools": tools_payload,
|
"tools": tools_payload,
|
||||||
})
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"CLI custom tools error: {e}")
|
logger.exception("CLI custom tools error: %s", e)
|
||||||
return JSONResponse({"success": False, "error": str(e)}, status_code=500)
|
return JSONResponse({"success": False, "error": str(e)}, status_code=500)
|
||||||
|
|
||||||
@mcp.custom_route("/api/instances", methods=["GET"])
|
|
||||||
async def cli_instances_route(_: Request) -> JSONResponse:
|
|
||||||
"""REST endpoint to list connected Unity instances."""
|
|
||||||
try:
|
|
||||||
sessions = await PluginHub.get_sessions()
|
|
||||||
instances = []
|
|
||||||
for session_id, details in sessions.sessions.items():
|
|
||||||
instances.append({
|
|
||||||
"session_id": session_id,
|
|
||||||
"project": details.project,
|
|
||||||
"hash": details.hash,
|
|
||||||
"unity_version": details.unity_version,
|
|
||||||
"connected_at": details.connected_at,
|
|
||||||
})
|
|
||||||
return JSONResponse({"success": True, "instances": instances})
|
|
||||||
except Exception as e:
|
|
||||||
return JSONResponse({"success": False, "error": str(e)}, status_code=500)
|
|
||||||
|
|
||||||
@mcp.custom_route("/plugin/sessions", methods=["GET"])
|
|
||||||
async def plugin_sessions_route(_: Request) -> JSONResponse:
|
|
||||||
data = await PluginHub.get_sessions()
|
|
||||||
return JSONResponse(data.model_dump())
|
|
||||||
|
|
||||||
# Initialize and register middleware for session-based Unity instance routing
|
# Initialize and register middleware for session-based Unity instance routing
|
||||||
# Using the singleton getter ensures we use the same instance everywhere
|
# Using the singleton getter ensures we use the same instance everywhere
|
||||||
unity_middleware = get_unity_instance_middleware()
|
unity_middleware = get_unity_instance_middleware()
|
||||||
mcp.add_middleware(unity_middleware)
|
mcp.add_middleware(unity_middleware)
|
||||||
logger.info("Registered Unity instance middleware for session-based routing")
|
logger.info("Registered Unity instance middleware for session-based routing")
|
||||||
|
|
||||||
|
# Initialize API key authentication if in remote-hosted mode
|
||||||
|
if config.http_remote_hosted and config.api_key_validation_url:
|
||||||
|
ApiKeyService(
|
||||||
|
validation_url=config.api_key_validation_url,
|
||||||
|
cache_ttl=config.api_key_cache_ttl,
|
||||||
|
service_token_header=config.api_key_service_token_header,
|
||||||
|
service_token=config.api_key_service_token,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Initialized API key authentication service (validation URL: %s, TTL: %.0fs)",
|
||||||
|
config.api_key_validation_url,
|
||||||
|
config.api_key_cache_ttl,
|
||||||
|
)
|
||||||
|
|
||||||
# Mount plugin websocket hub at /hub/plugin when HTTP transport is active
|
# Mount plugin websocket hub at /hub/plugin when HTTP transport is active
|
||||||
existing_routes = [
|
existing_routes = [
|
||||||
route for route in mcp._get_additional_http_routes()
|
route for route in mcp._get_additional_http_routes()
|
||||||
|
|
@ -610,6 +647,54 @@ Examples:
|
||||||
help="HTTP server port (overrides URL port). "
|
help="HTTP server port (overrides URL port). "
|
||||||
"Overrides UNITY_MCP_HTTP_PORT environment variable."
|
"Overrides UNITY_MCP_HTTP_PORT environment variable."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--http-remote-hosted",
|
||||||
|
action="store_true",
|
||||||
|
help="Treat HTTP transport as remotely hosted (forces explicit Unity instance selection). "
|
||||||
|
"Can also set via UNITY_MCP_HTTP_REMOTE_HOSTED=true."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--api-key-validation-url",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
metavar="URL",
|
||||||
|
help="External URL to validate API keys (POST with {'api_key': '...'}). "
|
||||||
|
"Required when --http-remote-hosted is set. "
|
||||||
|
"Can also set via UNITY_MCP_API_KEY_VALIDATION_URL."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--api-key-login-url",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
metavar="URL",
|
||||||
|
help="URL where users can obtain/manage API keys. "
|
||||||
|
"Returned by /api/auth/login-url endpoint. "
|
||||||
|
"Can also set via UNITY_MCP_API_KEY_LOGIN_URL."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--api-key-cache-ttl",
|
||||||
|
type=float,
|
||||||
|
default=300.0,
|
||||||
|
metavar="SECONDS",
|
||||||
|
help="Cache TTL for validated API keys in seconds (default: 300). "
|
||||||
|
"Can also set via UNITY_MCP_API_KEY_CACHE_TTL."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--api-key-service-token-header",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
metavar="HEADER",
|
||||||
|
help="Header name for service token sent to validation endpoint (e.g. X-Service-Token). "
|
||||||
|
"Can also set via UNITY_MCP_API_KEY_SERVICE_TOKEN_HEADER."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--api-key-service-token",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
metavar="TOKEN",
|
||||||
|
help="Service token value sent to validation endpoint for server authentication. "
|
||||||
|
"WARNING: Prefer UNITY_MCP_API_KEY_SERVICE_TOKEN env var in production to avoid process listing exposure."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--unity-instance-token",
|
"--unity-instance-token",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -630,6 +715,7 @@ Examples:
|
||||||
"--project-scoped-tools",
|
"--project-scoped-tools",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Keep custom tools scoped to the active Unity project and enable the custom tools resource. "
|
help="Keep custom tools scoped to the active Unity project and enable the custom tools resource. "
|
||||||
|
"Can also set via UNITY_MCP_PROJECT_SCOPED_TOOLS=true."
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -641,10 +727,52 @@ Examples:
|
||||||
f"Using default Unity instance from command-line: {args.default_instance}")
|
f"Using default Unity instance from command-line: {args.default_instance}")
|
||||||
|
|
||||||
# Set transport mode
|
# Set transport mode
|
||||||
transport_mode = args.transport or os.environ.get(
|
config.transport_mode = args.transport or os.environ.get(
|
||||||
"UNITY_MCP_TRANSPORT", "stdio")
|
"UNITY_MCP_TRANSPORT", "stdio")
|
||||||
os.environ["UNITY_MCP_TRANSPORT"] = transport_mode
|
logger.info(f"Transport mode: {config.transport_mode}")
|
||||||
logger.info(f"Transport mode: {transport_mode}")
|
|
||||||
|
config.http_remote_hosted = (
|
||||||
|
bool(args.http_remote_hosted)
|
||||||
|
or os.environ.get("UNITY_MCP_HTTP_REMOTE_HOSTED", "").lower() in ("true", "1", "yes", "on")
|
||||||
|
)
|
||||||
|
|
||||||
|
# API key authentication configuration
|
||||||
|
config.api_key_validation_url = (
|
||||||
|
args.api_key_validation_url
|
||||||
|
or os.environ.get("UNITY_MCP_API_KEY_VALIDATION_URL")
|
||||||
|
)
|
||||||
|
config.api_key_login_url = (
|
||||||
|
args.api_key_login_url
|
||||||
|
or os.environ.get("UNITY_MCP_API_KEY_LOGIN_URL")
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
cache_ttl_env = os.environ.get("UNITY_MCP_API_KEY_CACHE_TTL")
|
||||||
|
config.api_key_cache_ttl = (
|
||||||
|
float(cache_ttl_env) if cache_ttl_env else args.api_key_cache_ttl
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(
|
||||||
|
"Invalid UNITY_MCP_API_KEY_CACHE_TTL value, using default 300.0"
|
||||||
|
)
|
||||||
|
config.api_key_cache_ttl = 300.0
|
||||||
|
|
||||||
|
# Service token for authenticating to validation endpoint
|
||||||
|
config.api_key_service_token_header = (
|
||||||
|
args.api_key_service_token_header
|
||||||
|
or os.environ.get("UNITY_MCP_API_KEY_SERVICE_TOKEN_HEADER")
|
||||||
|
)
|
||||||
|
config.api_key_service_token = (
|
||||||
|
args.api_key_service_token
|
||||||
|
or os.environ.get("UNITY_MCP_API_KEY_SERVICE_TOKEN")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate: remote-hosted HTTP mode requires API key validation URL
|
||||||
|
if config.http_remote_hosted and config.transport_mode == "http" and not config.api_key_validation_url:
|
||||||
|
logger.error(
|
||||||
|
"--http-remote-hosted requires --api-key-validation-url or "
|
||||||
|
"UNITY_MCP_API_KEY_VALIDATION_URL environment variable"
|
||||||
|
)
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
http_url = os.environ.get("UNITY_MCP_HTTP_URL", args.http_url)
|
http_url = os.environ.get("UNITY_MCP_HTTP_URL", args.http_url)
|
||||||
parsed_url = urlparse(http_url)
|
parsed_url = urlparse(http_url)
|
||||||
|
|
@ -688,10 +816,14 @@ Examples:
|
||||||
if args.http_port:
|
if args.http_port:
|
||||||
logger.info(f"HTTP port override: {http_port}")
|
logger.info(f"HTTP port override: {http_port}")
|
||||||
|
|
||||||
mcp = create_mcp_server(args.project_scoped_tools)
|
project_scoped_tools = (
|
||||||
|
bool(args.project_scoped_tools)
|
||||||
|
or os.environ.get("UNITY_MCP_PROJECT_SCOPED_TOOLS", "").lower() in ("true", "1", "yes", "on")
|
||||||
|
)
|
||||||
|
mcp = create_mcp_server(project_scoped_tools)
|
||||||
|
|
||||||
# Determine transport mode
|
# Determine transport mode
|
||||||
if transport_mode == 'http':
|
if config.transport_mode == 'http':
|
||||||
# Use HTTP transport for FastMCP
|
# Use HTTP transport for FastMCP
|
||||||
transport = 'http'
|
transport = 'http'
|
||||||
# Use the parsed host and port from URL/args
|
# Use the parsed host and port from URL/args
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,235 @@
|
||||||
|
"""API Key validation service for remote-hosted mode."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger("mcp-for-unity-server")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationResult:
|
||||||
|
"""Result of an API key validation."""
|
||||||
|
valid: bool
|
||||||
|
user_id: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
error: str | None = None
|
||||||
|
cacheable: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyService:
|
||||||
|
"""Service for validating API keys against an external auth endpoint.
|
||||||
|
|
||||||
|
Follows the class-level singleton pattern for global access by MCP tools.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: "ApiKeyService | None" = None
|
||||||
|
|
||||||
|
# Request defaults (sensible hardening)
|
||||||
|
REQUEST_TIMEOUT: float = 5.0
|
||||||
|
MAX_RETRIES: int = 1
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
validation_url: str,
|
||||||
|
cache_ttl: float = 300.0,
|
||||||
|
service_token_header: str | None = None,
|
||||||
|
service_token: str | None = None,
|
||||||
|
):
|
||||||
|
"""Initialize the API key service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
validation_url: External URL to validate API keys (POST with {"api_key": "..."})
|
||||||
|
cache_ttl: Cache TTL for validated keys in seconds (default: 300)
|
||||||
|
service_token_header: Optional header name for service authentication (e.g. "X-Service-Token")
|
||||||
|
service_token: Optional token value for service authentication
|
||||||
|
"""
|
||||||
|
self._validation_url = validation_url
|
||||||
|
self._cache_ttl = cache_ttl
|
||||||
|
self._service_token_header = service_token_header
|
||||||
|
self._service_token = service_token
|
||||||
|
# Cache: api_key -> (valid, user_id, metadata, expires_at)
|
||||||
|
self._cache: dict[str, tuple[bool, str |
|
||||||
|
None, dict[str, Any] | None, float]] = {}
|
||||||
|
self._cache_lock = asyncio.Lock()
|
||||||
|
ApiKeyService._instance = self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls) -> "ApiKeyService":
|
||||||
|
"""Get the singleton instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the service has not been initialized.
|
||||||
|
"""
|
||||||
|
if cls._instance is None:
|
||||||
|
raise RuntimeError("ApiKeyService not initialized")
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_initialized(cls) -> bool:
|
||||||
|
"""Check if the service has been initialized."""
|
||||||
|
return cls._instance is not None
|
||||||
|
|
||||||
|
async def validate(self, api_key: str) -> ValidationResult:
|
||||||
|
"""Validate an API key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ValidationResult with valid=True and user_id if valid,
|
||||||
|
or valid=False with error message if invalid.
|
||||||
|
"""
|
||||||
|
if not api_key:
|
||||||
|
return ValidationResult(valid=False, error="API key required")
|
||||||
|
|
||||||
|
# Check cache first
|
||||||
|
async with self._cache_lock:
|
||||||
|
cached = self._cache.get(api_key)
|
||||||
|
if cached is not None:
|
||||||
|
valid, user_id, metadata, expires_at = cached
|
||||||
|
if time.time() < expires_at:
|
||||||
|
if valid:
|
||||||
|
return ValidationResult(valid=True, user_id=user_id, metadata=metadata)
|
||||||
|
else:
|
||||||
|
return ValidationResult(valid=False, error="Invalid API key")
|
||||||
|
else:
|
||||||
|
# Expired, remove from cache
|
||||||
|
del self._cache[api_key]
|
||||||
|
|
||||||
|
# Call external validation URL
|
||||||
|
result = await self._validate_external(api_key)
|
||||||
|
|
||||||
|
# Only cache definitive results (valid keys and confirmed-invalid keys).
|
||||||
|
# Transient failures (auth service unavailable, timeouts, etc.) should
|
||||||
|
# not be cached to avoid locking out users during service outages.
|
||||||
|
if result.cacheable:
|
||||||
|
async with self._cache_lock:
|
||||||
|
expires_at = time.time() + self._cache_ttl
|
||||||
|
self._cache[api_key] = (
|
||||||
|
result.valid,
|
||||||
|
result.user_id,
|
||||||
|
result.metadata,
|
||||||
|
expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _validate_external(self, api_key: str) -> ValidationResult:
|
||||||
|
"""Call external validation endpoint.
|
||||||
|
|
||||||
|
Failure mode: fail closed (treat as invalid on errors).
|
||||||
|
"""
|
||||||
|
# Redact API key from logs
|
||||||
|
redacted_key = f"{api_key[:4]}...{api_key[-4:]}" if len(
|
||||||
|
api_key) > 8 else "***"
|
||||||
|
|
||||||
|
for attempt in range(self.MAX_RETRIES + 1):
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as client:
|
||||||
|
# Build request headers
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if self._service_token_header and self._service_token:
|
||||||
|
headers[self._service_token_header] = self._service_token
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
self._validation_url,
|
||||||
|
json={"api_key": api_key},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
if data.get("valid"):
|
||||||
|
return ValidationResult(
|
||||||
|
valid=True,
|
||||||
|
user_id=data.get("user_id"),
|
||||||
|
metadata=data.get("metadata"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ValidationResult(
|
||||||
|
valid=False,
|
||||||
|
error=data.get("error", "Invalid API key"),
|
||||||
|
)
|
||||||
|
elif response.status_code == 401:
|
||||||
|
return ValidationResult(valid=False, error="Invalid API key")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"API key validation returned status %d for key %s",
|
||||||
|
response.status_code,
|
||||||
|
redacted_key,
|
||||||
|
)
|
||||||
|
# Fail closed but don't cache (transient service error)
|
||||||
|
return ValidationResult(
|
||||||
|
valid=False,
|
||||||
|
error=f"Auth service error (status {response.status_code})",
|
||||||
|
cacheable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
if attempt < self.MAX_RETRIES:
|
||||||
|
logger.debug(
|
||||||
|
"API key validation timeout for key %s, retrying...",
|
||||||
|
redacted_key,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.1 * (attempt + 1))
|
||||||
|
continue
|
||||||
|
logger.warning(
|
||||||
|
"API key validation timeout for key %s after %d attempts",
|
||||||
|
redacted_key,
|
||||||
|
attempt + 1,
|
||||||
|
)
|
||||||
|
return ValidationResult(
|
||||||
|
valid=False,
|
||||||
|
error="Auth service timeout",
|
||||||
|
cacheable=False,
|
||||||
|
)
|
||||||
|
except httpx.RequestError as exc:
|
||||||
|
if attempt < self.MAX_RETRIES:
|
||||||
|
logger.debug(
|
||||||
|
"API key validation request error for key %s: %s, retrying...",
|
||||||
|
redacted_key,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.1 * (attempt + 1))
|
||||||
|
continue
|
||||||
|
logger.warning(
|
||||||
|
"API key validation request error for key %s: %s",
|
||||||
|
redacted_key,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return ValidationResult(
|
||||||
|
valid=False,
|
||||||
|
error="Auth service unavailable",
|
||||||
|
cacheable=False,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"Unexpected error validating API key %s: %s",
|
||||||
|
redacted_key,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return ValidationResult(
|
||||||
|
valid=False,
|
||||||
|
error="Auth service error",
|
||||||
|
cacheable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not reach here, but fail closed
|
||||||
|
return ValidationResult(valid=False, error="Auth service error", cacheable=False)
|
||||||
|
|
||||||
|
async def invalidate_cache(self, api_key: str) -> None:
|
||||||
|
"""Remove an API key from the cache."""
|
||||||
|
async with self._cache_lock:
|
||||||
|
self._cache.pop(api_key, None)
|
||||||
|
|
||||||
|
async def clear_cache(self) -> None:
|
||||||
|
"""Clear all cached validations."""
|
||||||
|
async with self._cache_lock:
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ApiKeyService", "ValidationResult"]
|
||||||
|
|
@ -5,12 +5,14 @@ from typing import Any
|
||||||
from fastmcp import Context
|
from fastmcp import Context
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
from models import MCPResponse
|
from models import MCPResponse
|
||||||
from services.registry import mcp_for_unity_resource
|
from services.registry import mcp_for_unity_resource
|
||||||
from services.tools import get_unity_instance_from_context
|
from services.tools import get_unity_instance_from_context
|
||||||
from services.state.external_changes_scanner import external_changes_scanner
|
from services.state.external_changes_scanner import external_changes_scanner
|
||||||
import transport.unity_transport as unity_transport
|
import transport.unity_transport as unity_transport
|
||||||
from transport.legacy.unity_connection import async_send_command_with_retry
|
from transport.legacy.unity_connection import async_send_command_with_retry
|
||||||
|
from transport.plugin_hub import PluginHub
|
||||||
|
|
||||||
|
|
||||||
class EditorStateUnity(BaseModel):
|
class EditorStateUnity(BaseModel):
|
||||||
|
|
@ -132,17 +134,15 @@ async def infer_single_instance_id(ctx: Context) -> str | None:
|
||||||
"""
|
"""
|
||||||
await ctx.info("If exactly one Unity instance is connected, return its Name@hash id.")
|
await ctx.info("If exactly one Unity instance is connected, return its Name@hash id.")
|
||||||
|
|
||||||
try:
|
transport = (config.transport_mode or "stdio").lower()
|
||||||
transport = unity_transport._current_transport()
|
|
||||||
except Exception:
|
|
||||||
transport = None
|
|
||||||
|
|
||||||
if transport == "http":
|
if transport == "http":
|
||||||
# HTTP/WebSocket transport: derive from PluginHub sessions.
|
# HTTP/WebSocket transport: derive from PluginHub sessions.
|
||||||
try:
|
try:
|
||||||
from transport.plugin_hub import PluginHub
|
# In remote-hosted mode, filter sessions by user_id
|
||||||
|
user_id = ctx.get_state(
|
||||||
sessions_data = await PluginHub.get_sessions()
|
"user_id") if config.http_remote_hosted else None
|
||||||
|
sessions_data = await PluginHub.get_sessions(user_id=user_id)
|
||||||
sessions = sessions_data.sessions if hasattr(
|
sessions = sessions_data.sessions if hasattr(
|
||||||
sessions_data, "sessions") else {}
|
sessions_data, "sessions") else {}
|
||||||
if isinstance(sessions, dict) and len(sessions) == 1:
|
if isinstance(sessions, dict) and len(sessions) == 1:
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from fastmcp import Context
|
||||||
from services.registry import mcp_for_unity_resource
|
from services.registry import mcp_for_unity_resource
|
||||||
from transport.legacy.unity_connection import get_unity_connection_pool
|
from transport.legacy.unity_connection import get_unity_connection_pool
|
||||||
from transport.plugin_hub import PluginHub
|
from transport.plugin_hub import PluginHub
|
||||||
from transport.unity_transport import _current_transport
|
from core.config import config
|
||||||
|
|
||||||
|
|
||||||
@mcp_for_unity_resource(
|
@mcp_for_unity_resource(
|
||||||
|
|
@ -36,10 +36,13 @@ async def unity_instances(ctx: Context) -> dict[str, Any]:
|
||||||
await ctx.info("Listing Unity instances")
|
await ctx.info("Listing Unity instances")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
transport = _current_transport()
|
transport = (config.transport_mode or "stdio").lower()
|
||||||
if transport == "http":
|
if transport == "http":
|
||||||
# HTTP/WebSocket transport: query PluginHub
|
# HTTP/WebSocket transport: query PluginHub
|
||||||
sessions_data = await PluginHub.get_sessions()
|
# In remote-hosted mode, filter sessions by user_id
|
||||||
|
user_id = ctx.get_state(
|
||||||
|
"user_id") if config.http_remote_hosted else None
|
||||||
|
sessions_data = await PluginHub.get_sessions(user_id=user_id)
|
||||||
sessions = sessions_data.sessions
|
sessions = sessions_data.sessions
|
||||||
|
|
||||||
instances = []
|
instances = []
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from services.registry import mcp_for_unity_tool
|
||||||
from transport.legacy.unity_connection import get_unity_connection_pool
|
from transport.legacy.unity_connection import get_unity_connection_pool
|
||||||
from transport.unity_instance_middleware import get_unity_instance_middleware
|
from transport.unity_instance_middleware import get_unity_instance_middleware
|
||||||
from transport.plugin_hub import PluginHub
|
from transport.plugin_hub import PluginHub
|
||||||
from transport.unity_transport import _current_transport
|
from core.config import config
|
||||||
|
|
||||||
|
|
||||||
@mcp_for_unity_tool(
|
@mcp_for_unity_tool(
|
||||||
|
|
@ -21,11 +21,14 @@ async def set_active_instance(
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
instance: Annotated[str, "Target instance (Name@hash or hash prefix)"]
|
instance: Annotated[str, "Target instance (Name@hash or hash prefix)"]
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
transport = _current_transport()
|
transport = (config.transport_mode or "stdio").lower()
|
||||||
|
|
||||||
# Discover running instances based on transport
|
# Discover running instances based on transport
|
||||||
if transport == "http":
|
if transport == "http":
|
||||||
sessions_data = await PluginHub.get_sessions()
|
# In remote-hosted mode, filter sessions by user_id
|
||||||
|
user_id = ctx.get_state(
|
||||||
|
"user_id") if config.http_remote_hosted else None
|
||||||
|
sessions_data = await PluginHub.get_sessions(user_id=user_id)
|
||||||
sessions = sessions_data.sessions
|
sessions = sessions_data.sessions
|
||||||
instances = []
|
instances = []
|
||||||
for session_id, session in sessions.items():
|
for session_id, session in sessions.items():
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,10 @@ from starlette.endpoints import WebSocketEndpoint
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
from core.config import config
|
from core.config import config
|
||||||
|
from core.constants import API_KEY_HEADER
|
||||||
from models.models import MCPResponse
|
from models.models import MCPResponse
|
||||||
from transport.plugin_registry import PluginRegistry
|
from transport.plugin_registry import PluginRegistry
|
||||||
|
from services.api_key_service import ApiKeyService
|
||||||
from transport.models import (
|
from transport.models import (
|
||||||
WelcomeMessage,
|
WelcomeMessage,
|
||||||
RegisteredMessage,
|
RegisteredMessage,
|
||||||
|
|
@ -38,6 +40,22 @@ class NoUnitySessionError(RuntimeError):
|
||||||
"""Raised when no Unity plugins are available."""
|
"""Raised when no Unity plugins are available."""
|
||||||
|
|
||||||
|
|
||||||
|
class InstanceSelectionRequiredError(RuntimeError):
|
||||||
|
"""Raised when the caller must explicitly select a Unity instance."""
|
||||||
|
|
||||||
|
_SELECTION_REQUIRED = (
|
||||||
|
"Unity instance selection is required. "
|
||||||
|
"Call set_active_instance with Name@hash from mcpforunity://instances."
|
||||||
|
)
|
||||||
|
_MULTIPLE_INSTANCES = (
|
||||||
|
"Multiple Unity instances are connected. "
|
||||||
|
"Call set_active_instance with Name@hash from mcpforunity://instances."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, message: str | None = None):
|
||||||
|
super().__init__(message or self._SELECTION_REQUIRED)
|
||||||
|
|
||||||
|
|
||||||
class PluginHub(WebSocketEndpoint):
|
class PluginHub(WebSocketEndpoint):
|
||||||
"""Manages persistent WebSocket connections to Unity plugins."""
|
"""Manages persistent WebSocket connections to Unity plugins."""
|
||||||
|
|
||||||
|
|
@ -77,6 +95,50 @@ class PluginHub(WebSocketEndpoint):
|
||||||
return cls._registry is not None and cls._lock is not None
|
return cls._registry is not None and cls._lock is not None
|
||||||
|
|
||||||
async def on_connect(self, websocket: WebSocket) -> None:
|
async def on_connect(self, websocket: WebSocket) -> None:
|
||||||
|
# Validate API key in remote-hosted mode (fail closed)
|
||||||
|
if config.http_remote_hosted:
|
||||||
|
if not ApiKeyService.is_initialized():
|
||||||
|
logger.debug(
|
||||||
|
"WebSocket connection rejected: auth service not initialized")
|
||||||
|
await websocket.close(code=1013, reason="Try again later")
|
||||||
|
return
|
||||||
|
|
||||||
|
api_key = websocket.headers.get(API_KEY_HEADER)
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
logger.debug("WebSocket connection rejected: API key required")
|
||||||
|
await websocket.close(code=4401, reason="API key required")
|
||||||
|
return
|
||||||
|
|
||||||
|
service = ApiKeyService.get_instance()
|
||||||
|
result = await service.validate(api_key)
|
||||||
|
|
||||||
|
if not result.valid:
|
||||||
|
# Transient auth failures are retryable (1013)
|
||||||
|
if result.error and any(
|
||||||
|
indicator in result.error.lower()
|
||||||
|
for indicator in ("unavailable", "timeout", "service error")
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
"WebSocket connection rejected: auth service unavailable")
|
||||||
|
await websocket.close(code=1013, reason="Try again later")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug("WebSocket connection rejected: invalid API key")
|
||||||
|
await websocket.close(code=4403, reason="Invalid API key")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Both valid and user_id must be present to accept
|
||||||
|
if not result.user_id:
|
||||||
|
logger.debug(
|
||||||
|
"WebSocket connection rejected: validated key missing user_id")
|
||||||
|
await websocket.close(code=4403, reason="Invalid API key")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Store user_id in websocket state for later use during registration
|
||||||
|
websocket.state.user_id = result.user_id
|
||||||
|
websocket.state.api_key_metadata = result.metadata
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
msg = WelcomeMessage(
|
msg = WelcomeMessage(
|
||||||
serverTimeout=self.SERVER_TIMEOUT,
|
serverTimeout=self.SERVER_TIMEOUT,
|
||||||
|
|
@ -217,10 +279,15 @@ class PluginHub(WebSocketEndpoint):
|
||||||
cls._pending.pop(command_id, None)
|
cls._pending.pop(command_id, None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_sessions(cls) -> SessionList:
|
async def get_sessions(cls, user_id: str | None = None) -> SessionList:
|
||||||
|
"""Get all active plugin sessions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: If provided (remote-hosted mode), only return sessions for this user.
|
||||||
|
"""
|
||||||
if cls._registry is None:
|
if cls._registry is None:
|
||||||
return SessionList(sessions={})
|
return SessionList(sessions={})
|
||||||
sessions = await cls._registry.list_sessions()
|
sessions = await cls._registry.list_sessions(user_id=user_id)
|
||||||
return SessionList(
|
return SessionList(
|
||||||
sessions={
|
sessions={
|
||||||
session_id: SessionDetails(
|
session_id: SessionDetails(
|
||||||
|
|
@ -286,14 +353,22 @@ class PluginHub(WebSocketEndpoint):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Plugin registration missing project_hash")
|
"Plugin registration missing project_hash")
|
||||||
|
|
||||||
|
# Get user_id from websocket state (set during API key validation)
|
||||||
|
user_id = getattr(websocket.state, "user_id", None)
|
||||||
|
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
# Inform the plugin of its assigned session ID
|
# Inform the plugin of its assigned session ID
|
||||||
response = RegisteredMessage(session_id=session_id)
|
response = RegisteredMessage(session_id=session_id)
|
||||||
await websocket.send_json(response.model_dump())
|
await websocket.send_json(response.model_dump())
|
||||||
|
|
||||||
session = await registry.register(session_id, project_name, project_hash, unity_version, project_path)
|
session = await registry.register(session_id, project_name, project_hash, unity_version, project_path, user_id=user_id)
|
||||||
async with lock:
|
async with lock:
|
||||||
cls._connections[session.session_id] = websocket
|
cls._connections[session.session_id] = websocket
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
logger.info(
|
||||||
|
f"Plugin registered: {project_name} ({project_hash}) for user {user_id}")
|
||||||
|
else:
|
||||||
logger.info(f"Plugin registered: {project_name} ({project_hash})")
|
logger.info(f"Plugin registered: {project_name} ({project_hash})")
|
||||||
|
|
||||||
async def _handle_register_tools(self, websocket: WebSocket, payload: RegisterToolsMessage) -> None:
|
async def _handle_register_tools(self, websocket: WebSocket, payload: RegisterToolsMessage) -> None:
|
||||||
|
|
@ -375,13 +450,17 @@ class PluginHub(WebSocketEndpoint):
|
||||||
# Session resolution helpers
|
# Session resolution helpers
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _resolve_session_id(cls, unity_instance: str | None) -> str:
|
async def _resolve_session_id(cls, unity_instance: str | None, user_id: str | None = None) -> str:
|
||||||
"""Resolve a project hash (Unity instance id) to an active plugin session.
|
"""Resolve a project hash (Unity instance id) to an active plugin session.
|
||||||
|
|
||||||
During Unity domain reloads the plugin's WebSocket session is torn down
|
During Unity domain reloads the plugin's WebSocket session is torn down
|
||||||
and reconnected shortly afterwards. Instead of failing immediately when
|
and reconnected shortly afterwards. Instead of failing immediately when
|
||||||
no sessions are available, we wait for a bounded period for a plugin
|
no sessions are available, we wait for a bounded period for a plugin
|
||||||
to reconnect so in-flight MCP calls can succeed transparently.
|
to reconnect so in-flight MCP calls can succeed transparently.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unity_instance: Target instance (Name@hash or hash)
|
||||||
|
user_id: User ID from API key validation (for remote-hosted mode session isolation)
|
||||||
"""
|
"""
|
||||||
if cls._registry is None:
|
if cls._registry is None:
|
||||||
raise RuntimeError("Plugin registry not configured")
|
raise RuntimeError("Plugin registry not configured")
|
||||||
|
|
@ -411,24 +490,35 @@ class PluginHub(WebSocketEndpoint):
|
||||||
else:
|
else:
|
||||||
target_hash = unity_instance
|
target_hash = unity_instance
|
||||||
|
|
||||||
async def _try_once() -> tuple[str | None, int]:
|
async def _try_once() -> tuple[str | None, int, bool]:
|
||||||
|
explicit_required = config.http_remote_hosted
|
||||||
# Prefer a specific Unity instance if one was requested
|
# Prefer a specific Unity instance if one was requested
|
||||||
if target_hash:
|
if target_hash:
|
||||||
|
# In remote-hosted mode with user_id, use user-scoped lookup
|
||||||
|
if config.http_remote_hosted and user_id:
|
||||||
|
session_id = await cls._registry.get_session_id_by_hash(target_hash, user_id)
|
||||||
|
sessions = await cls._registry.list_sessions(user_id=user_id)
|
||||||
|
else:
|
||||||
session_id = await cls._registry.get_session_id_by_hash(target_hash)
|
session_id = await cls._registry.get_session_id_by_hash(target_hash)
|
||||||
sessions = await cls._registry.list_sessions()
|
sessions = await cls._registry.list_sessions(user_id=user_id)
|
||||||
return session_id, len(sessions)
|
return session_id, len(sessions), explicit_required
|
||||||
|
|
||||||
# No target provided: determine if we can auto-select
|
# No target provided: determine if we can auto-select
|
||||||
sessions = await cls._registry.list_sessions()
|
# In remote-hosted mode, filter sessions by user_id
|
||||||
|
sessions = await cls._registry.list_sessions(user_id=user_id)
|
||||||
count = len(sessions)
|
count = len(sessions)
|
||||||
if count == 0:
|
if count == 0:
|
||||||
return None, count
|
return None, count, explicit_required
|
||||||
|
if explicit_required:
|
||||||
|
return None, count, explicit_required
|
||||||
if count == 1:
|
if count == 1:
|
||||||
return next(iter(sessions.keys())), count
|
return next(iter(sessions.keys())), count, explicit_required
|
||||||
# Multiple sessions but no explicit target is ambiguous
|
# Multiple sessions but no explicit target is ambiguous
|
||||||
return None, count
|
return None, count, explicit_required
|
||||||
|
|
||||||
session_id, session_count = await _try_once()
|
session_id, session_count, explicit_required = await _try_once()
|
||||||
|
if session_id is None and explicit_required and not target_hash and session_count > 0:
|
||||||
|
raise InstanceSelectionRequiredError()
|
||||||
deadline = time.monotonic() + max_wait_s
|
deadline = time.monotonic() + max_wait_s
|
||||||
wait_started = None
|
wait_started = None
|
||||||
|
|
||||||
|
|
@ -436,10 +526,10 @@ class PluginHub(WebSocketEndpoint):
|
||||||
# wait politely for a session to appear before surfacing an error.
|
# wait politely for a session to appear before surfacing an error.
|
||||||
while session_id is None and time.monotonic() < deadline:
|
while session_id is None and time.monotonic() < deadline:
|
||||||
if not target_hash and session_count > 1:
|
if not target_hash and session_count > 1:
|
||||||
raise RuntimeError(
|
raise InstanceSelectionRequiredError(
|
||||||
"Multiple Unity instances are connected. "
|
InstanceSelectionRequiredError._MULTIPLE_INSTANCES)
|
||||||
"Call set_active_instance with Name@hash from mcpforunity://instances."
|
if session_id is None and explicit_required and not target_hash and session_count > 0:
|
||||||
)
|
raise InstanceSelectionRequiredError()
|
||||||
if wait_started is None:
|
if wait_started is None:
|
||||||
wait_started = time.monotonic()
|
wait_started = time.monotonic()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
@ -448,7 +538,7 @@ class PluginHub(WebSocketEndpoint):
|
||||||
max_wait_s,
|
max_wait_s,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(sleep_seconds)
|
await asyncio.sleep(sleep_seconds)
|
||||||
session_id, session_count = await _try_once()
|
session_id, session_count, explicit_required = await _try_once()
|
||||||
|
|
||||||
if session_id is not None and wait_started is not None:
|
if session_id is not None and wait_started is not None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
@ -457,10 +547,11 @@ class PluginHub(WebSocketEndpoint):
|
||||||
unity_instance or "default",
|
unity_instance or "default",
|
||||||
)
|
)
|
||||||
if session_id is None and not target_hash and session_count > 1:
|
if session_id is None and not target_hash and session_count > 1:
|
||||||
raise RuntimeError(
|
raise InstanceSelectionRequiredError(
|
||||||
"Multiple Unity instances are connected. "
|
InstanceSelectionRequiredError._MULTIPLE_INSTANCES)
|
||||||
"Call set_active_instance with Name@hash from mcpforunity://instances."
|
|
||||||
)
|
if session_id is None and explicit_required and not target_hash and session_count > 0:
|
||||||
|
raise InstanceSelectionRequiredError()
|
||||||
|
|
||||||
if session_id is None:
|
if session_id is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -481,9 +572,18 @@ class PluginHub(WebSocketEndpoint):
|
||||||
unity_instance: str | None,
|
unity_instance: str | None,
|
||||||
command_type: str,
|
command_type: str,
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
|
user_id: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
|
"""Send a command to a Unity instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unity_instance: Target instance (Name@hash or hash)
|
||||||
|
command_type: Command type to execute
|
||||||
|
params: Command parameters
|
||||||
|
user_id: User ID for session isolation in remote-hosted mode
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
session_id = await cls._resolve_session_id(unity_instance)
|
session_id = await cls._resolve_session_id(unity_instance, user_id=user_id)
|
||||||
except NoUnitySessionError:
|
except NoUnitySessionError:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Unity session unavailable; returning retry: command=%s instance=%s",
|
"Unity session unavailable; returning retry: command=%s instance=%s",
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from datetime import datetime, timezone
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
from models.models import ToolDefinitionModel
|
from models.models import ToolDefinitionModel
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -22,7 +23,9 @@ class PluginSession:
|
||||||
connected_at: datetime
|
connected_at: datetime
|
||||||
tools: dict[str, ToolDefinitionModel] = field(default_factory=dict)
|
tools: dict[str, ToolDefinitionModel] = field(default_factory=dict)
|
||||||
project_id: str | None = None
|
project_id: str | None = None
|
||||||
project_path: str | None = None # Full path to project root (for focus nudging)
|
# Full path to project root (for focus nudging)
|
||||||
|
project_path: str | None = None
|
||||||
|
user_id: str | None = None # Associated user id (None for local mode)
|
||||||
|
|
||||||
|
|
||||||
class PluginRegistry:
|
class PluginRegistry:
|
||||||
|
|
@ -31,11 +34,17 @@ class PluginRegistry:
|
||||||
The registry is optimised for quick lookup by either ``session_id`` or
|
The registry is optimised for quick lookup by either ``session_id`` or
|
||||||
``project_hash`` (which is used as the canonical "instance id" across the
|
``project_hash`` (which is used as the canonical "instance id" across the
|
||||||
HTTP command routing stack).
|
HTTP command routing stack).
|
||||||
|
|
||||||
|
In remote-hosted mode, sessions are scoped by (user_id, project_hash) composite key
|
||||||
|
to ensure session isolation between users.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._sessions: dict[str, PluginSession] = {}
|
self._sessions: dict[str, PluginSession] = {}
|
||||||
|
# In local mode: project_hash -> session_id
|
||||||
|
# In remote mode: (user_id, project_hash) -> session_id
|
||||||
self._hash_to_session: dict[str, str] = {}
|
self._hash_to_session: dict[str, str] = {}
|
||||||
|
self._user_hash_to_session: dict[tuple[str, str], str] = {}
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
async def register(
|
async def register(
|
||||||
|
|
@ -45,13 +54,16 @@ class PluginRegistry:
|
||||||
project_hash: str,
|
project_hash: str,
|
||||||
unity_version: str,
|
unity_version: str,
|
||||||
project_path: str | None = None,
|
project_path: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> PluginSession:
|
) -> PluginSession:
|
||||||
"""Register (or replace) a plugin session.
|
"""Register (or replace) a plugin session.
|
||||||
|
|
||||||
If an existing session already claims the same ``project_hash`` it will be
|
If an existing session already claims the same ``project_hash`` (and ``user_id``
|
||||||
replaced, ensuring that reconnect scenarios always map to the latest
|
in remote-hosted mode) it will be replaced, ensuring that reconnect scenarios
|
||||||
WebSocket connection.
|
always map to the latest WebSocket connection.
|
||||||
"""
|
"""
|
||||||
|
if config.http_remote_hosted and not user_id:
|
||||||
|
raise ValueError("user_id is required in remote-hosted mode")
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
@ -63,15 +75,26 @@ class PluginRegistry:
|
||||||
registered_at=now,
|
registered_at=now,
|
||||||
connected_at=now,
|
connected_at=now,
|
||||||
project_path=project_path,
|
project_path=project_path,
|
||||||
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove old mapping for this hash if it existed under a different session
|
# Remove old mapping for this hash if it existed under a different session
|
||||||
|
if user_id:
|
||||||
|
# Remote-hosted mode: use composite key (user_id, project_hash)
|
||||||
|
composite_key = (user_id, project_hash)
|
||||||
|
previous_session_id = self._user_hash_to_session.get(
|
||||||
|
composite_key)
|
||||||
|
if previous_session_id and previous_session_id != session_id:
|
||||||
|
self._sessions.pop(previous_session_id, None)
|
||||||
|
self._user_hash_to_session[composite_key] = session_id
|
||||||
|
else:
|
||||||
|
# Local mode: use project_hash only
|
||||||
previous_session_id = self._hash_to_session.get(project_hash)
|
previous_session_id = self._hash_to_session.get(project_hash)
|
||||||
if previous_session_id and previous_session_id != session_id:
|
if previous_session_id and previous_session_id != session_id:
|
||||||
self._sessions.pop(previous_session_id, None)
|
self._sessions.pop(previous_session_id, None)
|
||||||
|
self._hash_to_session[project_hash] = session_id
|
||||||
|
|
||||||
self._sessions[session_id] = session
|
self._sessions[session_id] = session
|
||||||
self._hash_to_session[project_hash] = session_id
|
|
||||||
return session
|
return session
|
||||||
|
|
||||||
async def touch(self, session_id: str) -> None:
|
async def touch(self, session_id: str) -> None:
|
||||||
|
|
@ -87,12 +110,21 @@ class PluginRegistry:
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
session = self._sessions.pop(session_id, None)
|
session = self._sessions.pop(session_id, None)
|
||||||
if session and session.project_hash in self._hash_to_session:
|
if session:
|
||||||
# Only delete the mapping if it still points at the removed session.
|
# Clean up hash mappings
|
||||||
|
if session.project_hash in self._hash_to_session:
|
||||||
mapped = self._hash_to_session.get(session.project_hash)
|
mapped = self._hash_to_session.get(session.project_hash)
|
||||||
if mapped == session_id:
|
if mapped == session_id:
|
||||||
del self._hash_to_session[session.project_hash]
|
del self._hash_to_session[session.project_hash]
|
||||||
|
|
||||||
|
# Clean up user-scoped mappings
|
||||||
|
if session.user_id:
|
||||||
|
composite_key = (session.user_id, session.project_hash)
|
||||||
|
if composite_key in self._user_hash_to_session:
|
||||||
|
mapped = self._user_hash_to_session.get(composite_key)
|
||||||
|
if mapped == session_id:
|
||||||
|
del self._user_hash_to_session[composite_key]
|
||||||
|
|
||||||
async def register_tools_for_session(self, session_id: str, tools: list[ToolDefinitionModel]) -> None:
|
async def register_tools_for_session(self, session_id: str, tools: list[ToolDefinitionModel]) -> None:
|
||||||
"""Register tools for a specific session."""
|
"""Register tools for a specific session."""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
|
|
@ -110,17 +142,41 @@ class PluginRegistry:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
return self._sessions.get(session_id)
|
return self._sessions.get(session_id)
|
||||||
|
|
||||||
async def get_session_id_by_hash(self, project_hash: str) -> str | None:
|
async def get_session_id_by_hash(self, project_hash: str, user_id: str | None = None) -> str | None:
|
||||||
"""Resolve a ``project_hash`` (Unity instance id) to a session id."""
|
"""Resolve a ``project_hash`` (Unity instance id) to a session id."""
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
async with self._lock:
|
||||||
|
return self._user_hash_to_session.get((user_id, project_hash))
|
||||||
|
else:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
return self._hash_to_session.get(project_hash)
|
return self._hash_to_session.get(project_hash)
|
||||||
|
|
||||||
async def list_sessions(self) -> dict[str, PluginSession]:
|
async def list_sessions(self, user_id: str | None = None) -> dict[str, PluginSession]:
|
||||||
"""Return a shallow copy of all known sessions."""
|
"""Return a shallow copy of sessions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: If provided, only return sessions for this user (remote-hosted mode).
|
||||||
|
If None, return all sessions (local mode only).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If ``user_id`` is None while running in remote-hosted mode.
|
||||||
|
This prevents accidentally leaking sessions across users.
|
||||||
|
"""
|
||||||
|
if user_id is None and config.http_remote_hosted:
|
||||||
|
raise ValueError(
|
||||||
|
"list_sessions requires user_id in remote-hosted mode"
|
||||||
|
)
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
|
if user_id is None:
|
||||||
return dict(self._sessions)
|
return dict(self._sessions)
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
sid: session
|
||||||
|
for sid, session in self._sessions.items()
|
||||||
|
if session.user_id == user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["PluginRegistry", "PluginSession"]
|
__all__ = ["PluginRegistry", "PluginSession"]
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import logging
|
||||||
|
|
||||||
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
from transport.plugin_hub import PluginHub
|
from transport.plugin_hub import PluginHub
|
||||||
|
|
||||||
logger = logging.getLogger("mcp-for-unity-server")
|
logger = logging.getLogger("mcp-for-unity-server")
|
||||||
|
|
@ -32,7 +33,12 @@ def get_unity_instance_middleware() -> 'UnityInstanceMiddleware':
|
||||||
|
|
||||||
|
|
||||||
def set_unity_instance_middleware(middleware: 'UnityInstanceMiddleware') -> None:
|
def set_unity_instance_middleware(middleware: 'UnityInstanceMiddleware') -> None:
|
||||||
"""Set the global Unity instance middleware (called during server initialization)."""
|
"""Replace the global middleware instance.
|
||||||
|
|
||||||
|
This is a test seam: production code uses ``get_unity_instance_middleware()``
|
||||||
|
which lazy-initialises the singleton. Tests call this function to inject a
|
||||||
|
mock or pre-configured middleware before exercising tool/resource code.
|
||||||
|
"""
|
||||||
global _unity_instance_middleware
|
global _unity_instance_middleware
|
||||||
_unity_instance_middleware = middleware
|
_unity_instance_middleware = middleware
|
||||||
|
|
||||||
|
|
@ -55,13 +61,18 @@ class UnityInstanceMiddleware(Middleware):
|
||||||
Derive a stable key for the calling session.
|
Derive a stable key for the calling session.
|
||||||
|
|
||||||
Prioritizes client_id for stability.
|
Prioritizes client_id for stability.
|
||||||
If client_id is missing, falls back to 'global' (assuming single-user local mode),
|
In remote-hosted mode, falls back to user_id for session isolation.
|
||||||
ignoring session_id which can be unstable in some transports/clients.
|
Otherwise falls back to 'global' (assuming single-user local mode).
|
||||||
"""
|
"""
|
||||||
client_id = getattr(ctx, "client_id", None)
|
client_id = getattr(ctx, "client_id", None)
|
||||||
if isinstance(client_id, str) and client_id:
|
if isinstance(client_id, str) and client_id:
|
||||||
return client_id
|
return client_id
|
||||||
|
|
||||||
|
# In remote-hosted mode, use user_id so different users get isolated instance selections
|
||||||
|
user_id = ctx.get_state("user_id")
|
||||||
|
if isinstance(user_id, str) and user_id:
|
||||||
|
return f"user:{user_id}"
|
||||||
|
|
||||||
# Fallback to global for local dev stability
|
# Fallback to global for local dev stability
|
||||||
return "global"
|
return "global"
|
||||||
|
|
||||||
|
|
@ -92,10 +103,10 @@ class UnityInstanceMiddleware(Middleware):
|
||||||
to stick for subsequent tool/resource calls in the same session.
|
to stick for subsequent tool/resource calls in the same session.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Import here to avoid circular dependencies / optional transport modules.
|
transport = (config.transport_mode or "stdio").lower()
|
||||||
from transport.unity_transport import _current_transport
|
# This implicit behavior works well for solo-users, but is dangerous for multi-user setups
|
||||||
|
if transport == "http" and config.http_remote_hosted:
|
||||||
transport = _current_transport()
|
return None
|
||||||
if PluginHub.is_configured():
|
if PluginHub.is_configured():
|
||||||
try:
|
try:
|
||||||
sessions_data = await PluginHub.get_sessions()
|
sessions_data = await PluginHub.get_sessions()
|
||||||
|
|
@ -172,10 +183,27 @@ class UnityInstanceMiddleware(Middleware):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _resolve_user_id(self) -> str | None:
|
||||||
|
"""Extract user_id from the current HTTP request's API key."""
|
||||||
|
if not config.http_remote_hosted:
|
||||||
|
return None
|
||||||
|
# Lazy import to avoid circular dependencies (same pattern as _maybe_autoselect_instance).
|
||||||
|
from transport.unity_transport import _resolve_user_id_from_request
|
||||||
|
return await _resolve_user_id_from_request()
|
||||||
|
|
||||||
async def _inject_unity_instance(self, context: MiddlewareContext) -> None:
|
async def _inject_unity_instance(self, context: MiddlewareContext) -> None:
|
||||||
"""Inject active Unity instance into context if available."""
|
"""Inject active Unity instance and user_id into context if available."""
|
||||||
ctx = context.fastmcp_context
|
ctx = context.fastmcp_context
|
||||||
|
|
||||||
|
# Resolve user_id from the HTTP request's API key header
|
||||||
|
user_id = await self._resolve_user_id()
|
||||||
|
if config.http_remote_hosted and user_id is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"API key authentication required. Provide a valid X-API-Key header."
|
||||||
|
)
|
||||||
|
if user_id:
|
||||||
|
ctx.set_state("user_id", user_id)
|
||||||
|
|
||||||
active_instance = self.get_active_instance(ctx)
|
active_instance = self.get_active_instance(ctx)
|
||||||
if not active_instance:
|
if not active_instance:
|
||||||
active_instance = await self._maybe_autoselect_instance(ctx)
|
active_instance = await self._maybe_autoselect_instance(ctx)
|
||||||
|
|
@ -193,7 +221,8 @@ class UnityInstanceMiddleware(Middleware):
|
||||||
# resolving session_id might fail if the plugin disconnected
|
# resolving session_id might fail if the plugin disconnected
|
||||||
# We only need session_id for HTTP transport routing.
|
# We only need session_id for HTTP transport routing.
|
||||||
# For stdio, we just need the instance ID.
|
# For stdio, we just need the instance ID.
|
||||||
session_id = await PluginHub._resolve_session_id(active_instance)
|
# Pass user_id for remote-hosted mode session isolation
|
||||||
|
session_id = await PluginHub._resolve_session_id(active_instance, user_id=user_id)
|
||||||
except (ConnectionError, ValueError, KeyError, TimeoutError) as exc:
|
except (ConnectionError, ValueError, KeyError, TimeoutError) as exc:
|
||||||
# If resolution fails, it means the Unity instance is not reachable via HTTP/WS.
|
# If resolution fails, it means the Unity instance is not reachable via HTTP/WS.
|
||||||
# If we are in stdio mode, this might still be fine if the user is just setting state?
|
# If we are in stdio mode, this might still be fine if the user is just setting state?
|
||||||
|
|
|
||||||
|
|
@ -1,34 +1,49 @@
|
||||||
"""Transport helpers for routing commands to Unity."""
|
"""Transport helpers for routing commands to Unity."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import logging
|
||||||
import inspect
|
|
||||||
import os
|
|
||||||
from typing import Awaitable, Callable, TypeVar
|
from typing import Awaitable, Callable, TypeVar
|
||||||
|
|
||||||
from fastmcp import Context
|
|
||||||
|
|
||||||
from transport.plugin_hub import PluginHub
|
from transport.plugin_hub import PluginHub
|
||||||
|
from core.config import config
|
||||||
|
from core.constants import API_KEY_HEADER
|
||||||
|
from services.api_key_service import ApiKeyService
|
||||||
from models.models import MCPResponse
|
from models.models import MCPResponse
|
||||||
from models.unity_response import normalize_unity_response
|
from models.unity_response import normalize_unity_response
|
||||||
from services.tools import get_unity_instance_from_context
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
logger = logging.getLogger("mcp-for-unity-server")
|
||||||
|
|
||||||
|
|
||||||
def _is_http_transport() -> bool:
|
def _is_http_transport() -> bool:
|
||||||
return os.environ.get("UNITY_MCP_TRANSPORT", "stdio").lower() == "http"
|
return config.transport_mode.lower() == "http"
|
||||||
|
|
||||||
|
|
||||||
def _current_transport() -> str:
|
async def _resolve_user_id_from_request() -> str | None:
|
||||||
"""Expose the active transport mode as a simple string identifier."""
|
"""Extract user_id from the current HTTP request's API key header."""
|
||||||
return "http" if _is_http_transport() else "stdio"
|
if not config.http_remote_hosted:
|
||||||
|
return None
|
||||||
|
if not ApiKeyService.is_initialized():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
from fastmcp.server.dependencies import get_http_headers
|
||||||
|
headers = get_http_headers(include_all=True)
|
||||||
|
api_key = headers.get(API_KEY_HEADER.lower())
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
service = ApiKeyService.get_instance()
|
||||||
|
result = await service.validate(api_key)
|
||||||
|
return result.user_id if result.valid else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Failed to resolve user_id from HTTP request: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def send_with_unity_instance(
|
async def send_with_unity_instance(
|
||||||
send_fn: Callable[..., Awaitable[T]],
|
send_fn: Callable[..., Awaitable[T]],
|
||||||
unity_instance: str | None,
|
unity_instance: str | None,
|
||||||
*args,
|
*args,
|
||||||
|
user_id: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> T:
|
) -> T:
|
||||||
if _is_http_transport():
|
if _is_http_transport():
|
||||||
|
|
@ -41,11 +56,27 @@ async def send_with_unity_instance(
|
||||||
if not isinstance(params, dict):
|
if not isinstance(params, dict):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Command parameters must be a dict for HTTP transport")
|
"Command parameters must be a dict for HTTP transport")
|
||||||
|
|
||||||
|
# Auto-resolve user_id from HTTP request API key (remote-hosted mode)
|
||||||
|
if user_id is None:
|
||||||
|
user_id = await _resolve_user_id_from_request()
|
||||||
|
|
||||||
|
# Auth check
|
||||||
|
if config.http_remote_hosted and not user_id:
|
||||||
|
return normalize_unity_response(
|
||||||
|
MCPResponse(
|
||||||
|
success=False,
|
||||||
|
error="auth_required",
|
||||||
|
message="API key required",
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw = await PluginHub.send_command_for_instance(
|
raw = await PluginHub.send_command_for_instance(
|
||||||
unity_instance,
|
unity_instance,
|
||||||
command_type,
|
command_type,
|
||||||
params,
|
params,
|
||||||
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
return normalize_unity_response(raw)
|
return normalize_unity_response(raw)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Pytest configuration for unity-mcp tests."""
|
"""Pytest configuration for unity-mcp tests."""
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -58,3 +59,29 @@ def pytest_collection_modifyitems(session, config, items): # noqa: ARG001
|
||||||
|
|
||||||
# Reorder: characterization/unit tests first, then integration tests
|
# Reorder: characterization/unit tests first, then integration tests
|
||||||
items[:] = other_tests + integration_tests
|
items[:] = other_tests + integration_tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def restore_global_config():
|
||||||
|
"""Restore global config/env mutations between tests."""
|
||||||
|
from core.config import config as global_config
|
||||||
|
|
||||||
|
prior_env = os.environ.get("UNITY_MCP_TRANSPORT")
|
||||||
|
prior = {
|
||||||
|
"transport_mode": global_config.transport_mode,
|
||||||
|
"http_remote_hosted": global_config.http_remote_hosted,
|
||||||
|
"api_key_validation_url": global_config.api_key_validation_url,
|
||||||
|
"api_key_login_url": global_config.api_key_login_url,
|
||||||
|
"api_key_cache_ttl": global_config.api_key_cache_ttl,
|
||||||
|
"api_key_service_token_header": global_config.api_key_service_token_header,
|
||||||
|
"api_key_service_token": global_config.api_key_service_token,
|
||||||
|
}
|
||||||
|
yield
|
||||||
|
|
||||||
|
if prior_env is None:
|
||||||
|
os.environ.pop("UNITY_MCP_TRANSPORT", None)
|
||||||
|
else:
|
||||||
|
os.environ["UNITY_MCP_TRANSPORT"] = prior_env
|
||||||
|
|
||||||
|
for key, value in prior.items():
|
||||||
|
setattr(global_config, key, value)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,456 @@
|
||||||
|
"""Tests for ApiKeyService: validation, caching, retries, and singleton lifecycle."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from services.api_key_service import ApiKeyService, ValidationResult
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_singleton():
|
||||||
|
"""Reset the ApiKeyService singleton between tests."""
|
||||||
|
ApiKeyService._instance = None
|
||||||
|
yield
|
||||||
|
ApiKeyService._instance = None
|
||||||
|
|
||||||
|
|
||||||
|
def _make_service(
|
||||||
|
validation_url="https://auth.example.com/validate",
|
||||||
|
cache_ttl=300.0,
|
||||||
|
service_token_header=None,
|
||||||
|
service_token=None,
|
||||||
|
):
|
||||||
|
return ApiKeyService(
|
||||||
|
validation_url=validation_url,
|
||||||
|
cache_ttl=cache_ttl,
|
||||||
|
service_token_header=service_token_header,
|
||||||
|
service_token=service_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_response(status_code=200, json_data=None):
|
||||||
|
resp = MagicMock(spec=httpx.Response)
|
||||||
|
resp.status_code = status_code
|
||||||
|
resp.json.return_value = json_data or {}
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Singleton lifecycle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingletonLifecycle:
|
||||||
|
def test_get_instance_before_init_raises(self):
|
||||||
|
with pytest.raises(RuntimeError, match="not initialized"):
|
||||||
|
ApiKeyService.get_instance()
|
||||||
|
|
||||||
|
def test_is_initialized_false_before_init(self):
|
||||||
|
assert ApiKeyService.is_initialized() is False
|
||||||
|
|
||||||
|
def test_is_initialized_true_after_init(self):
|
||||||
|
_make_service()
|
||||||
|
assert ApiKeyService.is_initialized() is True
|
||||||
|
|
||||||
|
def test_get_instance_returns_service(self):
|
||||||
|
svc = _make_service()
|
||||||
|
assert ApiKeyService.get_instance() is svc
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Basic validation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBasicValidation:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_key(self):
|
||||||
|
svc = _make_service()
|
||||||
|
mock_resp = _mock_response(
|
||||||
|
200, {"valid": True, "user_id": "user-1", "metadata": {"plan": "pro"}})
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = AsyncMock(return_value=mock_resp)
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
result = await svc.validate("test-valid-key-12345678")
|
||||||
|
|
||||||
|
assert result.valid is True
|
||||||
|
assert result.user_id == "user-1"
|
||||||
|
assert result.metadata == {"plan": "pro"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_key_200_body(self):
|
||||||
|
svc = _make_service()
|
||||||
|
mock_resp = _mock_response(
|
||||||
|
200, {"valid": False, "error": "Key revoked"})
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = AsyncMock(return_value=mock_resp)
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
result = await svc.validate("test-invalid-key-1234")
|
||||||
|
|
||||||
|
assert result.valid is False
|
||||||
|
assert result.error == "Key revoked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_key_401_status(self):
|
||||||
|
svc = _make_service()
|
||||||
|
mock_resp = _mock_response(401)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = AsyncMock(return_value=mock_resp)
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
result = await svc.validate("test-bad-key-12345678")
|
||||||
|
|
||||||
|
assert result.valid is False
|
||||||
|
assert "Invalid API key" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_key_fast_path(self):
|
||||||
|
svc = _make_service()
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
result = await svc.validate("")
|
||||||
|
|
||||||
|
assert result.valid is False
|
||||||
|
assert "required" in result.error.lower()
|
||||||
|
# No HTTP call should have been made
|
||||||
|
MockClient.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Caching
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCaching:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_hit_valid_key(self):
|
||||||
|
svc = _make_service(cache_ttl=300.0)
|
||||||
|
mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"})
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def counting_post(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = counting_post
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
r1 = await svc.validate("test-cached-valid-key1")
|
||||||
|
r2 = await svc.validate("test-cached-valid-key1")
|
||||||
|
|
||||||
|
assert r1.valid is True
|
||||||
|
assert r2.valid is True
|
||||||
|
assert r2.user_id == "u1"
|
||||||
|
assert call_count == 1 # Only one HTTP call
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_hit_invalid_key(self):
|
||||||
|
svc = _make_service(cache_ttl=300.0)
|
||||||
|
mock_resp = _mock_response(200, {"valid": False, "error": "bad"})
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def counting_post(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = counting_post
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
r1 = await svc.validate("test-cached-bad-key12")
|
||||||
|
r2 = await svc.validate("test-cached-bad-key12")
|
||||||
|
|
||||||
|
assert r1.valid is False
|
||||||
|
assert r2.valid is False
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_expiry(self):
|
||||||
|
svc = _make_service(cache_ttl=1.0) # 1 second TTL
|
||||||
|
mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"})
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def counting_post(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = counting_post
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
await svc.validate("test-expiry-key-12345")
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Manually expire the cache entry by manipulating the stored tuple
|
||||||
|
async with svc._cache_lock:
|
||||||
|
key = "test-expiry-key-12345"
|
||||||
|
valid, user_id, metadata, _expires = svc._cache[key]
|
||||||
|
svc._cache[key] = (valid, user_id, metadata, time.time() - 1)
|
||||||
|
|
||||||
|
await svc.validate("test-expiry-key-12345")
|
||||||
|
assert call_count == 2 # Had to re-validate
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalidate_cache(self):
|
||||||
|
svc = _make_service(cache_ttl=300.0)
|
||||||
|
mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"})
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def counting_post(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = counting_post
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
await svc.validate("test-invalidate-key12")
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
await svc.invalidate_cache("test-invalidate-key12")
|
||||||
|
|
||||||
|
await svc.validate("test-invalidate-key12")
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_cache(self):
|
||||||
|
svc = _make_service(cache_ttl=300.0)
|
||||||
|
mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"})
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def counting_post(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = counting_post
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
await svc.validate("test-clear-key1-12345")
|
||||||
|
await svc.validate("test-clear-key2-12345")
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
await svc.clear_cache()
|
||||||
|
|
||||||
|
await svc.validate("test-clear-key1-12345")
|
||||||
|
await svc.validate("test-clear-key2-12345")
|
||||||
|
assert call_count == 4 # Both had to re-validate
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Transient failures & retries
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransientFailures:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_5xx_not_cached(self):
|
||||||
|
svc = _make_service(cache_ttl=300.0)
|
||||||
|
mock_500 = _mock_response(500)
|
||||||
|
mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"})
|
||||||
|
responses = [mock_500, mock_500, mock_ok] # Extra for retry
|
||||||
|
call_idx = 0
|
||||||
|
|
||||||
|
async def sequential_post(*args, **kwargs):
|
||||||
|
nonlocal call_idx
|
||||||
|
resp = responses[min(call_idx, len(responses) - 1)]
|
||||||
|
call_idx += 1
|
||||||
|
return resp
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = sequential_post
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
# First call: 500 -> not cached
|
||||||
|
r1 = await svc.validate("test-5xx-test-key1234")
|
||||||
|
assert r1.valid is False
|
||||||
|
assert r1.cacheable is False
|
||||||
|
|
||||||
|
# Second call should hit HTTP again (not cached)
|
||||||
|
r2 = await svc.validate("test-5xx-test-key1234")
|
||||||
|
# Second call also gets 500 from our mock sequence
|
||||||
|
assert r2.valid is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_then_retry_succeeds(self):
|
||||||
|
svc = _make_service()
|
||||||
|
mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"})
|
||||||
|
attempt = 0
|
||||||
|
|
||||||
|
async def timeout_then_ok(*args, **kwargs):
|
||||||
|
nonlocal attempt
|
||||||
|
attempt += 1
|
||||||
|
if attempt == 1:
|
||||||
|
raise httpx.TimeoutException("timed out")
|
||||||
|
return mock_ok
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = timeout_then_ok
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
result = await svc.validate("test-timeout-retry-ok")
|
||||||
|
|
||||||
|
assert result.valid is True
|
||||||
|
assert result.user_id == "u1"
|
||||||
|
assert attempt == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_exhausts_retries(self):
|
||||||
|
svc = _make_service()
|
||||||
|
|
||||||
|
async def always_timeout(*args, **kwargs):
|
||||||
|
raise httpx.TimeoutException("timed out")
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = always_timeout
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
result = await svc.validate("test-timeout-exhaust1")
|
||||||
|
|
||||||
|
assert result.valid is False
|
||||||
|
assert "timeout" in result.error.lower()
|
||||||
|
assert result.cacheable is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_error_then_retry_succeeds(self):
|
||||||
|
svc = _make_service()
|
||||||
|
mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"})
|
||||||
|
attempt = 0
|
||||||
|
|
||||||
|
async def error_then_ok(*args, **kwargs):
|
||||||
|
nonlocal attempt
|
||||||
|
attempt += 1
|
||||||
|
if attempt == 1:
|
||||||
|
raise httpx.ConnectError("connection refused")
|
||||||
|
return mock_ok
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = error_then_ok
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
result = await svc.validate("test-reqerr-retry-ok1")
|
||||||
|
|
||||||
|
assert result.valid is True
|
||||||
|
assert attempt == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_error_exhausts_retries(self):
|
||||||
|
svc = _make_service()
|
||||||
|
|
||||||
|
async def always_error(*args, **kwargs):
|
||||||
|
raise httpx.ConnectError("connection refused")
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = always_error
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
result = await svc.validate("test-reqerr-exhaust1")
|
||||||
|
|
||||||
|
assert result.valid is False
|
||||||
|
assert "unavailable" in result.error.lower()
|
||||||
|
assert result.cacheable is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unexpected_exception(self):
|
||||||
|
svc = _make_service()
|
||||||
|
|
||||||
|
async def unexpected(*args, **kwargs):
|
||||||
|
raise ValueError("something unexpected")
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = unexpected
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
result = await svc.validate("test-unexpected-err12")
|
||||||
|
|
||||||
|
assert result.valid is False
|
||||||
|
assert result.cacheable is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Service token
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestServiceToken:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_token_sent_in_headers(self):
|
||||||
|
svc = _make_service(
|
||||||
|
service_token_header="X-Service-Token",
|
||||||
|
service_token="test-svc-token-123",
|
||||||
|
)
|
||||||
|
mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"})
|
||||||
|
captured_headers = {}
|
||||||
|
|
||||||
|
async def capture_post(url, *, json=None, headers=None):
|
||||||
|
captured_headers.update(headers or {})
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
instance.post = capture_post
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
await svc.validate("test-svctoken-key1234")
|
||||||
|
|
||||||
|
assert captured_headers.get("X-Service-Token") == "test-svc-token-123"
|
||||||
|
assert captured_headers.get("Content-Type") == "application/json"
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
"""Tests for auth configuration validation and startup routes."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _restore_config(monkeypatch):
|
||||||
|
"""Prevent main() side effects on the global config from leaking to other tests."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", config.http_remote_hosted)
|
||||||
|
monkeypatch.setattr(config, "api_key_validation_url", config.api_key_validation_url)
|
||||||
|
monkeypatch.setattr(config, "api_key_login_url", config.api_key_login_url)
|
||||||
|
monkeypatch.setattr(config, "api_key_cache_ttl", config.api_key_cache_ttl)
|
||||||
|
monkeypatch.setattr(config, "api_key_service_token_header", config.api_key_service_token_header)
|
||||||
|
monkeypatch.setattr(config, "api_key_service_token", config.api_key_service_token)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStartupConfigValidation:
|
||||||
|
def test_remote_hosted_flag_without_validation_url_exits(self, monkeypatch):
|
||||||
|
"""--http-remote-hosted without --api-key-validation-url should SystemExit(1)."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
sys,
|
||||||
|
"argv",
|
||||||
|
[
|
||||||
|
"main",
|
||||||
|
"--transport", "http",
|
||||||
|
"--http-remote-hosted",
|
||||||
|
# Deliberately omit --api-key-validation-url
|
||||||
|
],
|
||||||
|
)
|
||||||
|
monkeypatch.delenv("UNITY_MCP_API_KEY_VALIDATION_URL", raising=False)
|
||||||
|
monkeypatch.delenv("UNITY_MCP_HTTP_REMOTE_HOSTED", raising=False)
|
||||||
|
|
||||||
|
from main import main
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit) as exc_info:
|
||||||
|
main()
|
||||||
|
|
||||||
|
assert exc_info.value.code == 1
|
||||||
|
|
||||||
|
def test_remote_hosted_env_var_without_validation_url_exits(self, monkeypatch):
|
||||||
|
"""UNITY_MCP_HTTP_REMOTE_HOSTED=true without validation URL should SystemExit(1)."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
sys,
|
||||||
|
"argv",
|
||||||
|
[
|
||||||
|
"main",
|
||||||
|
"--transport", "http",
|
||||||
|
# No --http-remote-hosted flag
|
||||||
|
],
|
||||||
|
)
|
||||||
|
monkeypatch.setenv("UNITY_MCP_HTTP_REMOTE_HOSTED", "true")
|
||||||
|
monkeypatch.delenv("UNITY_MCP_API_KEY_VALIDATION_URL", raising=False)
|
||||||
|
|
||||||
|
from main import main
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit) as exc_info:
|
||||||
|
main()
|
||||||
|
|
||||||
|
assert exc_info.value.code == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoginUrlEndpoint:
|
||||||
|
"""Test the /api/auth/login-url route handler logic.
|
||||||
|
|
||||||
|
These tests replicate the handler inline to avoid full MCP server construction.
|
||||||
|
The logic mirrors main.py's auth_login_url route exactly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _auth_login_url(_request):
|
||||||
|
"""Replicate the route handler from main.py."""
|
||||||
|
if not config.api_key_login_url:
|
||||||
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "API key management not configured. Contact your server administrator.",
|
||||||
|
},
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True,
|
||||||
|
"login_url": config.api_key_login_url,
|
||||||
|
})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_url_returns_url_when_configured(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(config, "api_key_login_url",
|
||||||
|
"https://app.example.com/keys")
|
||||||
|
|
||||||
|
response = await self._auth_login_url(MagicMock(spec=Request))
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = json.loads(response.body.decode())
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["login_url"] == "https://app.example.com/keys"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_url_returns_404_when_not_configured(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(config, "api_key_login_url", None)
|
||||||
|
|
||||||
|
response = await self._auth_login_url(MagicMock(spec=Request))
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
body = json.loads(response.body.decode())
|
||||||
|
assert body["success"] is False
|
||||||
|
assert "not configured" in body["error"]
|
||||||
|
|
@ -27,7 +27,7 @@ async def test_plugin_hub_waits_for_reconnection_during_reload():
|
||||||
# Third call: session appears (plugin reconnected)
|
# Third call: session appears (plugin reconnected)
|
||||||
call_count = [0]
|
call_count = [0]
|
||||||
|
|
||||||
async def mock_list_sessions():
|
async def mock_list_sessions(**kwargs):
|
||||||
call_count[0] += 1
|
call_count[0] += 1
|
||||||
if call_count[0] <= 2:
|
if call_count[0] <= 2:
|
||||||
# Plugin not yet reconnected
|
# Plugin not yet reconnected
|
||||||
|
|
@ -77,7 +77,7 @@ async def test_plugin_hub_fails_after_timeout():
|
||||||
# Create a mock registry that never returns sessions
|
# Create a mock registry that never returns sessions
|
||||||
mock_registry = AsyncMock(spec=PluginRegistry)
|
mock_registry = AsyncMock(spec=PluginRegistry)
|
||||||
|
|
||||||
async def mock_list_sessions():
|
async def mock_list_sessions(**kwargs):
|
||||||
return {} # Never returns sessions
|
return {} # Never returns sessions
|
||||||
|
|
||||||
mock_registry.list_sessions = mock_list_sessions
|
mock_registry.list_sessions = mock_list_sessions
|
||||||
|
|
@ -161,7 +161,7 @@ async def test_read_console_during_simulated_reload(monkeypatch):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_plugin_hub_respects_unity_instance_preference():
|
async def test_plugin_hub_respects_unity_instance_preference():
|
||||||
"""Test that _resolve_session_id prefers a specific Unity instance if requested."""
|
"""Test that _resolve_session_id prefers a specific Unity instance if requested."""
|
||||||
from transport.plugin_hub import PluginHub
|
from transport.plugin_hub import PluginHub, InstanceSelectionRequiredError
|
||||||
from transport.plugin_registry import PluginRegistry, PluginSession
|
from transport.plugin_registry import PluginRegistry, PluginSession
|
||||||
|
|
||||||
# Create a mock registry with two sessions
|
# Create a mock registry with two sessions
|
||||||
|
|
@ -185,19 +185,19 @@ async def test_plugin_hub_respects_unity_instance_preference():
|
||||||
connected_at=now
|
connected_at=now
|
||||||
)
|
)
|
||||||
|
|
||||||
async def mock_list_sessions():
|
async def mock_list_sessions(**kwargs):
|
||||||
return {
|
return {
|
||||||
"session-1": session1,
|
"session-1": session1,
|
||||||
"session-2": session2
|
"session-2": session2
|
||||||
}
|
}
|
||||||
|
|
||||||
async def mock_get_session_by_hash(project_hash):
|
async def mock_get_session_id_by_hash(project_hash, user_id=None):
|
||||||
if project_hash == "hash2":
|
if project_hash == "hash2":
|
||||||
return "session-2"
|
return "session-2"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
mock_registry.list_sessions = mock_list_sessions
|
mock_registry.list_sessions = mock_list_sessions
|
||||||
mock_registry.get_session_id_by_hash = mock_get_session_by_hash
|
mock_registry.get_session_id_by_hash = mock_get_session_id_by_hash
|
||||||
|
|
||||||
# Configure PluginHub with our mock while preserving the original state
|
# Configure PluginHub with our mock while preserving the original state
|
||||||
original_registry = PluginHub._registry
|
original_registry = PluginHub._registry
|
||||||
|
|
@ -213,9 +213,8 @@ async def test_plugin_hub_respects_unity_instance_preference():
|
||||||
assert session_id == "session-2"
|
assert session_id == "session-2"
|
||||||
|
|
||||||
# Request default (no specific instance)
|
# Request default (no specific instance)
|
||||||
with pytest.raises(RuntimeError) as exc:
|
with pytest.raises(InstanceSelectionRequiredError, match="Multiple Unity instances"):
|
||||||
await PluginHub._resolve_session_id(unity_instance=None)
|
await PluginHub._resolve_session_id(unity_instance=None)
|
||||||
assert "Multiple Unity instances are connected" in str(exc.value)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up: restore original PluginHub state
|
# Clean up: restore original PluginHub state
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import types
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
from .test_helpers import DummyContext
|
from .test_helpers import DummyContext
|
||||||
|
from core.config import config
|
||||||
|
|
||||||
|
|
||||||
class DummyMiddlewareContext:
|
class DummyMiddlewareContext:
|
||||||
|
|
@ -25,15 +26,12 @@ def test_auto_selects_single_instance_via_pluginhub(monkeypatch):
|
||||||
|
|
||||||
plugin_hub.PluginHub = PluginHub
|
plugin_hub.PluginHub = PluginHub
|
||||||
monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub)
|
monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub)
|
||||||
unity_transport = types.ModuleType("transport.unity_transport")
|
|
||||||
unity_transport._current_transport = lambda: "http"
|
|
||||||
monkeypatch.setitem(sys.modules, "transport.unity_transport", unity_transport)
|
|
||||||
monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False)
|
monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False)
|
||||||
|
|
||||||
from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub
|
from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub
|
||||||
assert ImportedPluginHub is plugin_hub.PluginHub
|
assert ImportedPluginHub is plugin_hub.PluginHub
|
||||||
|
|
||||||
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http")
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
|
|
||||||
middleware = UnityInstanceMiddleware()
|
middleware = UnityInstanceMiddleware()
|
||||||
ctx = DummyContext()
|
ctx = DummyContext()
|
||||||
|
|
@ -74,15 +72,12 @@ def test_auto_selects_single_instance_via_stdio(monkeypatch):
|
||||||
|
|
||||||
plugin_hub.PluginHub = PluginHub
|
plugin_hub.PluginHub = PluginHub
|
||||||
monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub)
|
monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub)
|
||||||
unity_transport = types.ModuleType("transport.unity_transport")
|
|
||||||
unity_transport._current_transport = lambda: "stdio"
|
|
||||||
monkeypatch.setitem(sys.modules, "transport.unity_transport", unity_transport)
|
|
||||||
monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False)
|
monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False)
|
||||||
|
|
||||||
from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub
|
from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub
|
||||||
assert ImportedPluginHub is plugin_hub.PluginHub
|
assert ImportedPluginHub is plugin_hub.PluginHub
|
||||||
|
|
||||||
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "stdio")
|
monkeypatch.setattr(config, "transport_mode", "stdio")
|
||||||
|
|
||||||
middleware = UnityInstanceMiddleware()
|
middleware = UnityInstanceMiddleware()
|
||||||
ctx = DummyContext()
|
ctx = DummyContext()
|
||||||
|
|
@ -118,9 +113,6 @@ def test_auto_select_handles_stdio_errors(monkeypatch):
|
||||||
|
|
||||||
plugin_hub.PluginHub = PluginHub
|
plugin_hub.PluginHub = PluginHub
|
||||||
monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub)
|
monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub)
|
||||||
unity_transport = types.ModuleType("transport.unity_transport")
|
|
||||||
unity_transport._current_transport = lambda: "stdio"
|
|
||||||
monkeypatch.setitem(sys.modules, "transport.unity_transport", unity_transport)
|
|
||||||
monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False)
|
monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False)
|
||||||
|
|
||||||
from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub
|
from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ import pytest
|
||||||
from unittest.mock import AsyncMock, Mock, MagicMock, patch
|
from unittest.mock import AsyncMock, Mock, MagicMock, patch
|
||||||
from fastmcp import Context
|
from fastmcp import Context
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
from transport.unity_instance_middleware import UnityInstanceMiddleware
|
from transport.unity_instance_middleware import UnityInstanceMiddleware
|
||||||
from services.tools import get_unity_instance_from_context
|
from services.tools import get_unity_instance_from_context
|
||||||
from services.tools.set_active_instance import set_active_instance as set_active_instance_tool
|
from services.tools.set_active_instance import set_active_instance as set_active_instance_tool
|
||||||
|
|
@ -28,6 +29,7 @@ class TestInstanceRoutingBasics:
|
||||||
middleware = UnityInstanceMiddleware()
|
middleware = UnityInstanceMiddleware()
|
||||||
ctx = Mock(spec=Context)
|
ctx = Mock(spec=Context)
|
||||||
ctx.session_id = "test-session-1"
|
ctx.session_id = "test-session-1"
|
||||||
|
ctx.client_id = "test-client-1"
|
||||||
|
|
||||||
# Set active instance
|
# Set active instance
|
||||||
middleware.set_active_instance(ctx, "TestProject@abc123")
|
middleware.set_active_instance(ctx, "TestProject@abc123")
|
||||||
|
|
@ -73,6 +75,7 @@ class TestInstanceRoutingBasics:
|
||||||
ctx = Mock(spec=Context)
|
ctx = Mock(spec=Context)
|
||||||
ctx.session_id = None
|
ctx.session_id = None
|
||||||
ctx.client_id = None
|
ctx.client_id = None
|
||||||
|
ctx.get_state = Mock(return_value=None)
|
||||||
|
|
||||||
middleware.set_active_instance(ctx, "Project@global")
|
middleware.set_active_instance(ctx, "Project@global")
|
||||||
assert middleware.get_active_instance(ctx) == "Project@global"
|
assert middleware.get_active_instance(ctx) == "Project@global"
|
||||||
|
|
@ -170,7 +173,7 @@ class TestInstanceRoutingHTTP:
|
||||||
v: state_storage.__setitem__(k, v))
|
v: state_storage.__setitem__(k, v))
|
||||||
ctx.get_state = Mock(side_effect=lambda k: state_storage.get(k))
|
ctx.get_state = Mock(side_effect=lambda k: state_storage.get(k))
|
||||||
|
|
||||||
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http")
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
fake_sessions = SessionList(
|
fake_sessions = SessionList(
|
||||||
sessions={
|
sessions={
|
||||||
"sess-1": SessionDetails(
|
"sess-1": SessionDetails(
|
||||||
|
|
@ -206,7 +209,7 @@ class TestInstanceRoutingHTTP:
|
||||||
v: state_storage.__setitem__(k, v))
|
v: state_storage.__setitem__(k, v))
|
||||||
ctx.get_state = Mock(side_effect=lambda k: state_storage.get(k))
|
ctx.get_state = Mock(side_effect=lambda k: state_storage.get(k))
|
||||||
|
|
||||||
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http")
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
fake_sessions = SessionList(
|
fake_sessions = SessionList(
|
||||||
sessions={
|
sessions={
|
||||||
"sess-99": SessionDetails(
|
"sess-99": SessionDetails(
|
||||||
|
|
@ -238,7 +241,7 @@ class TestInstanceRoutingHTTP:
|
||||||
ctx = Mock(spec=Context)
|
ctx = Mock(spec=Context)
|
||||||
ctx.session_id = "http-session-3"
|
ctx.session_id = "http-session-3"
|
||||||
|
|
||||||
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http")
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
fake_sessions = SessionList(sessions={})
|
fake_sessions = SessionList(sessions={})
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"services.tools.set_active_instance.PluginHub.get_sessions",
|
"services.tools.set_active_instance.PluginHub.get_sessions",
|
||||||
|
|
@ -261,7 +264,7 @@ class TestInstanceRoutingHTTP:
|
||||||
ctx = Mock(spec=Context)
|
ctx = Mock(spec=Context)
|
||||||
ctx.session_id = "http-session-4"
|
ctx.session_id = "http-session-4"
|
||||||
|
|
||||||
monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http")
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
fake_sessions = SessionList(
|
fake_sessions = SessionList(
|
||||||
sessions={
|
sessions={
|
||||||
"sess-a": SessionDetails(project="ProjA", hash="abc12345", unity_version="2022", connected_at="now"),
|
"sess-a": SessionDetails(project="ProjA", hash="abc12345", unity_version="2022", connected_at="now"),
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,172 @@
|
||||||
|
"""Tests for UnityInstanceMiddleware auth enforcement in remote-hosted mode."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
|
from tests.integration.test_helpers import DummyContext
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiddlewareAuthEnforcement:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remote_hosted_requires_user_id(self, monkeypatch):
|
||||||
|
"""_inject_unity_instance should raise RuntimeError when remote-hosted and no user_id."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
|
||||||
|
from transport.unity_instance_middleware import UnityInstanceMiddleware
|
||||||
|
|
||||||
|
middleware = UnityInstanceMiddleware()
|
||||||
|
|
||||||
|
# Mock _resolve_user_id to return None (no API key / failed validation)
|
||||||
|
monkeypatch.setattr(middleware, "_resolve_user_id",
|
||||||
|
AsyncMock(return_value=None))
|
||||||
|
|
||||||
|
ctx = DummyContext()
|
||||||
|
middleware_ctx = Mock()
|
||||||
|
middleware_ctx.fastmcp_context = ctx
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="API key authentication required"):
|
||||||
|
await middleware._inject_unity_instance(middleware_ctx)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sets_user_id_in_context_state(self, monkeypatch):
|
||||||
|
"""_inject_unity_instance should set user_id in ctx state when resolved."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
|
||||||
|
from transport.unity_instance_middleware import UnityInstanceMiddleware
|
||||||
|
|
||||||
|
middleware = UnityInstanceMiddleware()
|
||||||
|
monkeypatch.setattr(middleware, "_resolve_user_id",
|
||||||
|
AsyncMock(return_value="user-55"))
|
||||||
|
|
||||||
|
# We need PluginHub to be configured for the session resolution path
|
||||||
|
# But we don't need it to actually find a session for this test
|
||||||
|
from transport.plugin_hub import PluginHub
|
||||||
|
from transport.plugin_registry import PluginRegistry
|
||||||
|
|
||||||
|
registry = PluginRegistry()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
PluginHub.configure(registry, loop)
|
||||||
|
|
||||||
|
ctx = DummyContext()
|
||||||
|
ctx.client_id = "client-1"
|
||||||
|
middleware_ctx = Mock()
|
||||||
|
middleware_ctx.fastmcp_context = ctx
|
||||||
|
|
||||||
|
# Set an active instance so the middleware doesn't try to auto-select
|
||||||
|
middleware.set_active_instance(ctx, "Proj@hash1")
|
||||||
|
# Register a matching session so resolution doesn't fail
|
||||||
|
await registry.register("s1", "Proj", "hash1", "2022", user_id="user-55")
|
||||||
|
|
||||||
|
await middleware._inject_unity_instance(middleware_ctx)
|
||||||
|
|
||||||
|
assert ctx.get_state("user_id") == "user-55"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiddlewareSessionKey:
|
||||||
|
def test_get_session_key_uses_user_id_fallback(self):
|
||||||
|
"""When no client_id, middleware should use user:$user_id as session key."""
|
||||||
|
from transport.unity_instance_middleware import UnityInstanceMiddleware
|
||||||
|
|
||||||
|
middleware = UnityInstanceMiddleware()
|
||||||
|
|
||||||
|
ctx = DummyContext()
|
||||||
|
# Simulate no client_id attribute
|
||||||
|
if hasattr(ctx, "client_id"):
|
||||||
|
delattr(ctx, "client_id")
|
||||||
|
ctx.set_state("user_id", "user-77")
|
||||||
|
|
||||||
|
key = middleware.get_session_key(ctx)
|
||||||
|
assert key == "user:user-77"
|
||||||
|
|
||||||
|
def test_get_session_key_prefers_client_id(self):
|
||||||
|
"""client_id should take precedence over user_id."""
|
||||||
|
from transport.unity_instance_middleware import UnityInstanceMiddleware
|
||||||
|
|
||||||
|
middleware = UnityInstanceMiddleware()
|
||||||
|
|
||||||
|
ctx = DummyContext()
|
||||||
|
ctx.client_id = "client-abc"
|
||||||
|
ctx.set_state("user_id", "user-77")
|
||||||
|
|
||||||
|
key = middleware.get_session_key(ctx)
|
||||||
|
assert key == "client-abc"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAutoSelectDisabledRemoteHosted:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_select_returns_none_in_remote_hosted(self, monkeypatch):
|
||||||
|
"""_maybe_autoselect_instance should return None in remote-hosted mode even with one session."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
|
|
||||||
|
# Re-import middleware to pick up the stubbed transport module
|
||||||
|
monkeypatch.delitem(
|
||||||
|
sys.modules, "transport.unity_instance_middleware", raising=False)
|
||||||
|
from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as HubRef
|
||||||
|
|
||||||
|
# Configure PluginHub with one session so auto-select has something to find
|
||||||
|
from transport.plugin_registry import PluginRegistry
|
||||||
|
registry = PluginRegistry()
|
||||||
|
await registry.register("s1", "Proj", "h1", "2022", user_id="userA")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
HubRef.configure(registry, loop)
|
||||||
|
|
||||||
|
middleware = UnityInstanceMiddleware()
|
||||||
|
ctx = DummyContext()
|
||||||
|
ctx.client_id = "client-1"
|
||||||
|
|
||||||
|
result = await middleware._maybe_autoselect_instance(ctx)
|
||||||
|
# Remote-hosted mode should NOT auto-select (early return at the transport check)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttpAuthBehavior:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_local_does_not_require_user_id(self, monkeypatch):
|
||||||
|
"""HTTP local mode should allow requests without user_id."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", False)
|
||||||
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
|
|
||||||
|
from transport import unity_transport
|
||||||
|
|
||||||
|
async def fake_send_command_for_instance(*_args, **_kwargs):
|
||||||
|
return {"success": True, "data": {"ok": True}}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
unity_transport.PluginHub,
|
||||||
|
"send_command_for_instance",
|
||||||
|
fake_send_command_for_instance,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _unused_send_fn(*_args, **_kwargs):
|
||||||
|
raise AssertionError("send_fn should not be used in HTTP mode")
|
||||||
|
|
||||||
|
result = await unity_transport.send_with_unity_instance(
|
||||||
|
_unused_send_fn, None, "ping", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["data"] == {"ok": True}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_remote_requires_user_id(self, monkeypatch):
|
||||||
|
"""HTTP remote-hosted mode should reject requests without user_id."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
|
|
||||||
|
from transport import unity_transport
|
||||||
|
|
||||||
|
async def _unused_send_fn(*_args, **_kwargs):
|
||||||
|
raise AssertionError("send_fn should not be used in HTTP mode")
|
||||||
|
|
||||||
|
result = await unity_transport.send_with_unity_instance(
|
||||||
|
_unused_send_fn, None, "ping", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
assert result["error"] == "auth_required"
|
||||||
|
|
@ -0,0 +1,176 @@
|
||||||
|
"""Integration tests for multi-user session isolation in remote-hosted mode.
|
||||||
|
|
||||||
|
These tests compose PluginRegistry + PluginHub to verify that users
|
||||||
|
cannot see or interact with each other's Unity instances.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
|
from transport.plugin_hub import NoUnitySessionError, PluginHub
|
||||||
|
from transport.plugin_registry import PluginRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_plugin_hub():
|
||||||
|
old_registry = PluginHub._registry
|
||||||
|
old_connections = PluginHub._connections.copy()
|
||||||
|
old_pending = PluginHub._pending.copy()
|
||||||
|
old_lock = PluginHub._lock
|
||||||
|
old_loop = PluginHub._loop
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
PluginHub._registry = old_registry
|
||||||
|
PluginHub._connections = old_connections
|
||||||
|
PluginHub._pending = old_pending
|
||||||
|
PluginHub._lock = old_lock
|
||||||
|
PluginHub._loop = old_loop
|
||||||
|
|
||||||
|
|
||||||
|
async def _setup_two_user_registry():
|
||||||
|
"""Set up a registry with two users, each having Unity instances.
|
||||||
|
|
||||||
|
Returns the configured registry. Also configures PluginHub to use it.
|
||||||
|
"""
|
||||||
|
registry = PluginRegistry()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
PluginHub.configure(registry, loop)
|
||||||
|
|
||||||
|
await registry.register("sess-A1", "ProjectAlpha", "hashA1", "2022.3", user_id="userA")
|
||||||
|
await registry.register("sess-B1", "ProjectBeta", "hashB1", "2022.3", user_id="userB")
|
||||||
|
await registry.register("sess-A2", "ProjectGamma", "hashA2", "2022.3", user_id="userA")
|
||||||
|
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiUserSessionFiltering:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sessions_filters_by_user(self):
|
||||||
|
"""PluginHub.get_sessions(user_id=X) returns only X's sessions."""
|
||||||
|
await _setup_two_user_registry()
|
||||||
|
|
||||||
|
sessions_a = await PluginHub.get_sessions(user_id="userA")
|
||||||
|
assert len(sessions_a.sessions) == 2
|
||||||
|
project_names = {s.project for s in sessions_a.sessions.values()}
|
||||||
|
assert project_names == {"ProjectAlpha", "ProjectGamma"}
|
||||||
|
|
||||||
|
sessions_b = await PluginHub.get_sessions(user_id="userB")
|
||||||
|
assert len(sessions_b.sessions) == 1
|
||||||
|
assert next(iter(sessions_b.sessions.values())
|
||||||
|
).project == "ProjectBeta"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sessions_no_filter_returns_all_in_local_mode(self):
|
||||||
|
"""In local mode, PluginHub.get_sessions() without user_id returns everything."""
|
||||||
|
await _setup_two_user_registry()
|
||||||
|
|
||||||
|
all_sessions = await PluginHub.get_sessions()
|
||||||
|
assert len(all_sessions.sessions) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sessions_no_filter_raises_in_remote_hosted(self, monkeypatch):
|
||||||
|
"""In remote-hosted mode, PluginHub.get_sessions() without user_id raises."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
await _setup_two_user_registry()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="requires user_id"):
|
||||||
|
await PluginHub.get_sessions()
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveSessionIdIsolation:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_session_for_own_hash(self, monkeypatch):
|
||||||
|
"""User A can resolve their own project hash."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
await _setup_two_user_registry()
|
||||||
|
|
||||||
|
session_id = await PluginHub._resolve_session_id("hashA1", user_id="userA")
|
||||||
|
assert session_id == "sess-A1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cannot_resolve_other_users_hash(self, monkeypatch):
|
||||||
|
"""User A cannot resolve User B's project hash."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
monkeypatch.setenv("UNITY_MCP_SESSION_RESOLVE_MAX_WAIT_S", "0.1")
|
||||||
|
await _setup_two_user_registry()
|
||||||
|
|
||||||
|
# userA tries to resolve userB's hash -> should not find it
|
||||||
|
with pytest.raises(NoUnitySessionError):
|
||||||
|
await PluginHub._resolve_session_id("hashB1", user_id="userA")
|
||||||
|
|
||||||
|
|
||||||
|
class TestInstanceListResourceIsolation:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unity_instances_resource_filters_by_user(self, monkeypatch):
|
||||||
|
"""The unity_instances resource should pass user_id and return filtered results."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
|
await _setup_two_user_registry()
|
||||||
|
|
||||||
|
from services.resources.unity_instances import unity_instances
|
||||||
|
from tests.integration.test_helpers import DummyContext
|
||||||
|
|
||||||
|
ctx = DummyContext()
|
||||||
|
ctx.set_state("user_id", "userA")
|
||||||
|
|
||||||
|
result = await unity_instances(ctx)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["instance_count"] == 2
|
||||||
|
instance_names = {i["name"] for i in result["instances"]}
|
||||||
|
assert instance_names == {"ProjectAlpha", "ProjectGamma"}
|
||||||
|
assert "ProjectBeta" not in instance_names
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetActiveInstanceIsolation:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_active_instance_only_sees_own_sessions(self, monkeypatch):
|
||||||
|
"""set_active_instance should only offer sessions belonging to the current user."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
|
await _setup_two_user_registry()
|
||||||
|
|
||||||
|
from services.tools.set_active_instance import set_active_instance
|
||||||
|
from transport.unity_instance_middleware import UnityInstanceMiddleware
|
||||||
|
from tests.integration.test_helpers import DummyContext
|
||||||
|
|
||||||
|
middleware = UnityInstanceMiddleware()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"services.tools.set_active_instance.get_unity_instance_middleware",
|
||||||
|
lambda: middleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = DummyContext()
|
||||||
|
ctx.set_state("user_id", "userA")
|
||||||
|
|
||||||
|
result = await set_active_instance(ctx, "ProjectAlpha@hashA1")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert middleware.get_active_instance(ctx) == "ProjectAlpha@hashA1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_active_instance_rejects_other_users_instance(self, monkeypatch):
|
||||||
|
"""set_active_instance should not find another user's instance."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
|
await _setup_two_user_registry()
|
||||||
|
|
||||||
|
from services.tools.set_active_instance import set_active_instance
|
||||||
|
from transport.unity_instance_middleware import UnityInstanceMiddleware
|
||||||
|
from tests.integration.test_helpers import DummyContext
|
||||||
|
|
||||||
|
middleware = UnityInstanceMiddleware()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"services.tools.set_active_instance.get_unity_instance_middleware",
|
||||||
|
lambda: middleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = DummyContext()
|
||||||
|
ctx.set_state("user_id", "userA")
|
||||||
|
|
||||||
|
# UserA tries to select UserB's instance -> should fail
|
||||||
|
result = await set_active_instance(ctx, "ProjectBeta@hashB1")
|
||||||
|
assert result["success"] is False
|
||||||
|
|
@ -0,0 +1,183 @@
|
||||||
|
"""Tests for PluginHub WebSocket API key authentication gate."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
|
from core.constants import API_KEY_HEADER
|
||||||
|
from services.api_key_service import ApiKeyService, ValidationResult
|
||||||
|
from transport.plugin_hub import PluginHub
|
||||||
|
from transport.plugin_registry import PluginRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_api_key_singleton():
|
||||||
|
ApiKeyService._instance = None
|
||||||
|
yield
|
||||||
|
ApiKeyService._instance = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_plugin_hub():
|
||||||
|
"""Ensure PluginHub class-level state doesn't leak between tests."""
|
||||||
|
old_registry = PluginHub._registry
|
||||||
|
old_connections = PluginHub._connections.copy()
|
||||||
|
old_pending = PluginHub._pending.copy()
|
||||||
|
old_lock = PluginHub._lock
|
||||||
|
old_loop = PluginHub._loop
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
PluginHub._registry = old_registry
|
||||||
|
PluginHub._connections = old_connections
|
||||||
|
PluginHub._pending = old_pending
|
||||||
|
PluginHub._lock = old_lock
|
||||||
|
PluginHub._loop = old_loop
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_websocket(headers=None, state_attrs=None):
|
||||||
|
"""Create a mock WebSocket with configurable headers and state."""
|
||||||
|
ws = AsyncMock()
|
||||||
|
ws.headers = headers or {}
|
||||||
|
ws.state = SimpleNamespace(**(state_attrs or {}))
|
||||||
|
ws.accept = AsyncMock()
|
||||||
|
ws.close = AsyncMock()
|
||||||
|
ws.send_json = AsyncMock()
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
def _make_hub():
|
||||||
|
"""Create a PluginHub instance with a minimal ASGI scope."""
|
||||||
|
scope = {"type": "websocket"}
|
||||||
|
return PluginHub(scope, receive=AsyncMock(), send=AsyncMock())
|
||||||
|
|
||||||
|
|
||||||
|
def _init_api_key_service(validate_result=None):
|
||||||
|
"""Initialize ApiKeyService with a mocked validate method."""
|
||||||
|
svc = ApiKeyService(validation_url="https://auth.example.com/validate")
|
||||||
|
if validate_result is not None:
|
||||||
|
svc.validate = AsyncMock(return_value=validate_result)
|
||||||
|
return svc
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSocketAuthGate:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_api_key_remote_hosted_rejected(self, monkeypatch):
|
||||||
|
"""WebSocket without API key in remote-hosted mode -> close 4401."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
_init_api_key_service(ValidationResult(valid=True, user_id="u1"))
|
||||||
|
|
||||||
|
ws = _make_mock_websocket(headers={}) # No X-API-Key header
|
||||||
|
hub = _make_hub()
|
||||||
|
|
||||||
|
await hub.on_connect(ws)
|
||||||
|
|
||||||
|
ws.close.assert_called_once_with(code=4401, reason="API key required")
|
||||||
|
ws.accept.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_api_key_rejected(self, monkeypatch):
|
||||||
|
"""WebSocket with invalid API key -> close 4403."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
_init_api_key_service(ValidationResult(
|
||||||
|
valid=False, error="Invalid API key"))
|
||||||
|
|
||||||
|
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-bad-key"})
|
||||||
|
hub = _make_hub()
|
||||||
|
|
||||||
|
await hub.on_connect(ws)
|
||||||
|
|
||||||
|
ws.close.assert_called_once_with(code=4403, reason="Invalid API key")
|
||||||
|
ws.accept.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_api_key_accepted(self, monkeypatch):
|
||||||
|
"""WebSocket with valid API key -> accepted, user_id stored in state."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
_init_api_key_service(
|
||||||
|
ValidationResult(valid=True, user_id="user-42",
|
||||||
|
metadata={"plan": "pro"})
|
||||||
|
)
|
||||||
|
|
||||||
|
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-valid-key"})
|
||||||
|
hub = _make_hub()
|
||||||
|
|
||||||
|
await hub.on_connect(ws)
|
||||||
|
|
||||||
|
ws.accept.assert_called_once()
|
||||||
|
ws.close.assert_not_called()
|
||||||
|
assert ws.state.user_id == "user-42"
|
||||||
|
assert ws.state.api_key_metadata == {"plan": "pro"}
|
||||||
|
# Should have sent welcome message
|
||||||
|
ws.send_json.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_service_unavailable_close_1013(self, monkeypatch):
|
||||||
|
"""Auth service error with 'unavailable' -> close 1013 (try again later)."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
_init_api_key_service(
|
||||||
|
ValidationResult(
|
||||||
|
valid=False, error="Auth service unavailable", cacheable=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-some-key"})
|
||||||
|
hub = _make_hub()
|
||||||
|
|
||||||
|
await hub.on_connect(ws)
|
||||||
|
|
||||||
|
ws.close.assert_called_once_with(code=1013, reason="Try again later")
|
||||||
|
ws.accept.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_not_remote_hosted_accepts_without_key(self, monkeypatch):
|
||||||
|
"""When not remote-hosted, WebSocket accepted without API key."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", False)
|
||||||
|
|
||||||
|
ws = _make_mock_websocket(headers={})
|
||||||
|
hub = _make_hub()
|
||||||
|
|
||||||
|
await hub.on_connect(ws)
|
||||||
|
|
||||||
|
ws.accept.assert_called_once()
|
||||||
|
ws.close.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserIdFlowsToRegistration:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_id_passed_to_registry_on_register(self, monkeypatch):
|
||||||
|
"""After valid auth, the register message should pass user_id to registry."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
_init_api_key_service(
|
||||||
|
ValidationResult(valid=True, user_id="user-99")
|
||||||
|
)
|
||||||
|
|
||||||
|
registry = PluginRegistry()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
PluginHub.configure(registry, loop)
|
||||||
|
|
||||||
|
# Simulate full flow: connect, then register
|
||||||
|
ws = _make_mock_websocket(headers={API_KEY_HEADER: "sk-valid-key"})
|
||||||
|
hub = _make_hub()
|
||||||
|
|
||||||
|
await hub.on_connect(ws)
|
||||||
|
assert ws.state.user_id == "user-99"
|
||||||
|
|
||||||
|
# Simulate register message
|
||||||
|
register_data = {
|
||||||
|
"type": "register",
|
||||||
|
"project_name": "TestProject",
|
||||||
|
"project_hash": "abc123",
|
||||||
|
"unity_version": "2022.3",
|
||||||
|
}
|
||||||
|
await hub.on_receive(ws, register_data)
|
||||||
|
|
||||||
|
# Verify registry has the user_id
|
||||||
|
sessions = await registry.list_sessions(user_id="user-99")
|
||||||
|
assert len(sessions) == 1
|
||||||
|
session = next(iter(sessions.values()))
|
||||||
|
assert session.user_id == "user-99"
|
||||||
|
assert session.project_name == "TestProject"
|
||||||
|
assert session.project_hash == "abc123"
|
||||||
|
|
@ -0,0 +1,112 @@
|
||||||
|
"""Tests for PluginRegistry user-scoped session isolation (remote-hosted mode)."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
|
from transport.plugin_registry import PluginRegistry
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegistryUserIsolation:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_with_user_id_stores_composite_key(self):
|
||||||
|
registry = PluginRegistry()
|
||||||
|
session = await registry.register(
|
||||||
|
"sess-1", "MyProject", "hash1", "2022.3", user_id="user-A"
|
||||||
|
)
|
||||||
|
assert session.user_id == "user-A"
|
||||||
|
assert ("user-A", "hash1") in registry._user_hash_to_session
|
||||||
|
assert registry._user_hash_to_session[("user-A", "hash1")] == "sess-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_session_id_by_hash(self):
|
||||||
|
registry = PluginRegistry()
|
||||||
|
await registry.register("sess-1", "Proj", "h1", "2022", user_id="uA")
|
||||||
|
|
||||||
|
found = await registry.get_session_id_by_hash("h1", "uA")
|
||||||
|
assert found == "sess-1"
|
||||||
|
|
||||||
|
# Different user, same hash -> not found
|
||||||
|
not_found = await registry.get_session_id_by_hash("h1", "uB")
|
||||||
|
assert not_found is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cross_user_isolation_same_hash(self):
|
||||||
|
"""Two users registering with the same project_hash get independent sessions."""
|
||||||
|
registry = PluginRegistry()
|
||||||
|
sess_a = await registry.register("sA", "Proj", "hash1", "2022", user_id="userA")
|
||||||
|
sess_b = await registry.register("sB", "Proj", "hash1", "2022", user_id="userB")
|
||||||
|
|
||||||
|
assert sess_a.session_id == "sA"
|
||||||
|
assert sess_b.session_id == "sB"
|
||||||
|
|
||||||
|
# Each user resolves to their own session
|
||||||
|
assert await registry.get_session_id_by_hash("hash1", "userA") == "sA"
|
||||||
|
assert await registry.get_session_id_by_hash("hash1", "userB") == "sB"
|
||||||
|
|
||||||
|
# Both sessions exist
|
||||||
|
all_sessions = await registry.list_sessions()
|
||||||
|
assert len(all_sessions) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_sessions_filtered_by_user(self):
|
||||||
|
registry = PluginRegistry()
|
||||||
|
await registry.register("s1", "ProjA", "hA", "2022", user_id="userA")
|
||||||
|
await registry.register("s2", "ProjB", "hB", "2022", user_id="userB")
|
||||||
|
await registry.register("s3", "ProjC", "hC", "2022", user_id="userA")
|
||||||
|
|
||||||
|
user_a_sessions = await registry.list_sessions(user_id="userA")
|
||||||
|
assert len(user_a_sessions) == 2
|
||||||
|
assert "s1" in user_a_sessions
|
||||||
|
assert "s3" in user_a_sessions
|
||||||
|
|
||||||
|
user_b_sessions = await registry.list_sessions(user_id="userB")
|
||||||
|
assert len(user_b_sessions) == 1
|
||||||
|
assert "s2" in user_b_sessions
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_sessions_no_filter_returns_all_in_local_mode(self):
|
||||||
|
"""In local mode (not remote-hosted), list_sessions(user_id=None) returns all."""
|
||||||
|
registry = PluginRegistry()
|
||||||
|
await registry.register("s1", "P1", "h1", "2022", user_id="uA")
|
||||||
|
await registry.register("s2", "P2", "h2", "2022", user_id="uB")
|
||||||
|
await registry.register("s3", "P3", "h3", "2022") # local mode, no user_id
|
||||||
|
|
||||||
|
all_sessions = await registry.list_sessions(user_id=None)
|
||||||
|
assert len(all_sessions) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_sessions_no_filter_raises_in_remote_hosted(self, monkeypatch):
|
||||||
|
"""In remote-hosted mode, list_sessions(user_id=None) raises ValueError."""
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
|
||||||
|
registry = PluginRegistry()
|
||||||
|
await registry.register("s1", "P1", "h1", "2022", user_id="uA")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="requires user_id"):
|
||||||
|
await registry.list_sessions(user_id=None)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unregister_cleans_user_scoped_mapping(self):
|
||||||
|
registry = PluginRegistry()
|
||||||
|
await registry.register("s1", "Proj", "h1", "2022", user_id="uA")
|
||||||
|
assert ("uA", "h1") in registry._user_hash_to_session
|
||||||
|
|
||||||
|
await registry.unregister("s1")
|
||||||
|
|
||||||
|
assert ("uA", "h1") not in registry._user_hash_to_session
|
||||||
|
assert "s1" not in (await registry.list_sessions())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reconnect_replaces_previous_session(self):
|
||||||
|
"""Same (user_id, hash) re-registered evicts old session, stores new one."""
|
||||||
|
registry = PluginRegistry()
|
||||||
|
await registry.register("old-sess", "Proj", "h1", "2022", user_id="uA")
|
||||||
|
assert await registry.get_session_id_by_hash("h1", "uA") == "old-sess"
|
||||||
|
|
||||||
|
await registry.register("new-sess", "Proj", "h1", "2022", user_id="uA")
|
||||||
|
assert await registry.get_session_id_by_hash("h1", "uA") == "new-sess"
|
||||||
|
|
||||||
|
# Old session should be evicted
|
||||||
|
all_sessions = await registry.list_sessions()
|
||||||
|
assert "old-sess" not in all_sessions
|
||||||
|
assert "new-sess" in all_sessions
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
"""Tests for _resolve_user_id_from_request in unity_transport.py."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
|
from services.api_key_service import ApiKeyService, ValidationResult
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_api_key_singleton():
|
||||||
|
ApiKeyService._instance = None
|
||||||
|
yield
|
||||||
|
ApiKeyService._instance = None
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveUserIdFromRequest:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_not_remote_hosted(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", False)
|
||||||
|
|
||||||
|
from transport.unity_transport import _resolve_user_id_from_request
|
||||||
|
|
||||||
|
result = await _resolve_user_id_from_request()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_service_not_initialized(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
# ApiKeyService._instance is None (from fixture)
|
||||||
|
|
||||||
|
from transport.unity_transport import _resolve_user_id_from_request
|
||||||
|
|
||||||
|
result = await _resolve_user_id_from_request()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_user_id_for_valid_key(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
|
||||||
|
svc = ApiKeyService(validation_url="https://auth.example.com/validate")
|
||||||
|
svc.validate = AsyncMock(
|
||||||
|
return_value=ValidationResult(valid=True, user_id="user-123")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stub the fastmcp dependency that provides HTTP headers
|
||||||
|
deps_mod = types.ModuleType("fastmcp.server.dependencies")
|
||||||
|
deps_mod.get_http_headers = lambda include_all=False: {
|
||||||
|
"x-api-key": "sk-valid"}
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules, "fastmcp.server.dependencies", deps_mod)
|
||||||
|
|
||||||
|
from transport.unity_transport import _resolve_user_id_from_request
|
||||||
|
|
||||||
|
result = await _resolve_user_id_from_request()
|
||||||
|
assert result == "user-123"
|
||||||
|
svc.validate.assert_called_once_with("sk-valid")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_for_invalid_key(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
|
||||||
|
svc = ApiKeyService(validation_url="https://auth.example.com/validate")
|
||||||
|
svc.validate = AsyncMock(
|
||||||
|
return_value=ValidationResult(valid=False, error="bad key")
|
||||||
|
)
|
||||||
|
|
||||||
|
deps_mod = types.ModuleType("fastmcp.server.dependencies")
|
||||||
|
deps_mod.get_http_headers = lambda include_all=False: {
|
||||||
|
"x-api-key": "sk-bad"}
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules, "fastmcp.server.dependencies", deps_mod)
|
||||||
|
|
||||||
|
from transport.unity_transport import _resolve_user_id_from_request
|
||||||
|
|
||||||
|
result = await _resolve_user_id_from_request()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_on_exception(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
|
||||||
|
svc = ApiKeyService(validation_url="https://auth.example.com/validate")
|
||||||
|
svc.validate = AsyncMock(side_effect=RuntimeError("boom"))
|
||||||
|
|
||||||
|
deps_mod = types.ModuleType("fastmcp.server.dependencies")
|
||||||
|
deps_mod.get_http_headers = lambda include_all=False: {
|
||||||
|
"x-api-key": "sk-err"}
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules, "fastmcp.server.dependencies", deps_mod)
|
||||||
|
|
||||||
|
from transport.unity_transport import _resolve_user_id_from_request
|
||||||
|
|
||||||
|
result = await _resolve_user_id_from_request()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_no_api_key_header(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
|
||||||
|
ApiKeyService(validation_url="https://auth.example.com/validate")
|
||||||
|
|
||||||
|
deps_mod = types.ModuleType("fastmcp.server.dependencies")
|
||||||
|
deps_mod.get_http_headers = lambda include_all=False: {} # No x-api-key
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules, "fastmcp.server.dependencies", deps_mod)
|
||||||
|
|
||||||
|
from transport.unity_transport import _resolve_user_id_from_request
|
||||||
|
|
||||||
|
result = await _resolve_user_id_from_request()
|
||||||
|
assert result is None
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import logging
|
||||||
import types
|
import types
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
@ -7,7 +8,13 @@ import core.telemetry as telemetry
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_queue_backpressure_and_single_worker(monkeypatch, caplog):
|
def test_telemetry_queue_backpressure_and_single_worker(monkeypatch, caplog):
|
||||||
caplog.set_level("DEBUG")
|
# Directly attach caplog's handler to the telemetry logger so that
|
||||||
|
# earlier tests calling logging.basicConfig() can't steal the records
|
||||||
|
# via a root handler before caplog sees them.
|
||||||
|
tel_logger = logging.getLogger("unity-mcp-telemetry")
|
||||||
|
tel_logger.addHandler(caplog.handler)
|
||||||
|
try:
|
||||||
|
caplog.set_level("DEBUG", logger="unity-mcp-telemetry")
|
||||||
|
|
||||||
collector = telemetry.TelemetryCollector()
|
collector = telemetry.TelemetryCollector()
|
||||||
# Force-enable telemetry regardless of env settings from conftest
|
# Force-enable telemetry regardless of env settings from conftest
|
||||||
|
|
@ -18,8 +25,9 @@ def test_telemetry_queue_backpressure_and_single_worker(monkeypatch, caplog):
|
||||||
# Replace queue with tiny one to trigger backpressure quickly
|
# Replace queue with tiny one to trigger backpressure quickly
|
||||||
small_q = q.Queue(maxsize=2)
|
small_q = q.Queue(maxsize=2)
|
||||||
collector._queue = small_q
|
collector._queue = small_q
|
||||||
# Give the worker a moment to switch queues
|
# Give the worker time to finish processing the seeded item and
|
||||||
time.sleep(0.02)
|
# re-enter _queue.get() on the new small queue
|
||||||
|
time.sleep(0.2)
|
||||||
|
|
||||||
# Make sends slow to build backlog and exercise worker
|
# Make sends slow to build backlog and exercise worker
|
||||||
def slow_send(self, rec):
|
def slow_send(self, rec):
|
||||||
|
|
@ -52,3 +60,6 @@ def test_telemetry_queue_backpressure_and_single_worker(monkeypatch, caplog):
|
||||||
worker_threads = [
|
worker_threads = [
|
||||||
t for t in threading.enumerate() if t is collector._worker]
|
t for t in threading.enumerate() if t is collector._worker]
|
||||||
assert len(worker_threads) == 1
|
assert len(worker_threads) == 1
|
||||||
|
finally:
|
||||||
|
if caplog.handler in tel_logger.handlers:
|
||||||
|
tel_logger.removeHandler(caplog.handler)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
"""End-to-end-ish smoke tests for transport routing paths."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
|
from transport import unity_transport
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_local_smoke(monkeypatch):
|
||||||
|
"""HTTP local should route through PluginHub without requiring user_id."""
|
||||||
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", False)
|
||||||
|
|
||||||
|
async def fake_send_command_for_instance(_instance, _command, _params, **_kwargs):
|
||||||
|
return {"status": "success", "result": {"message": "ok", "data": {"via": "http"}}}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
unity_transport.PluginHub,
|
||||||
|
"send_command_for_instance",
|
||||||
|
fake_send_command_for_instance,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _unused_send_fn(*_args, **_kwargs):
|
||||||
|
raise AssertionError("send_fn should not be used in HTTP mode")
|
||||||
|
|
||||||
|
result = await unity_transport.send_with_unity_instance(
|
||||||
|
_unused_send_fn, None, "ping", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["data"] == {"via": "http"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_remote_smoke(monkeypatch):
|
||||||
|
"""HTTP remote-hosted should route through PluginHub when user_id is provided."""
|
||||||
|
monkeypatch.setattr(config, "transport_mode", "http")
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", True)
|
||||||
|
|
||||||
|
async def fake_send_command_for_instance(_instance, _command, _params, **_kwargs):
|
||||||
|
return {"status": "success", "result": {"data": {"via": "http-remote"}}}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
unity_transport.PluginHub,
|
||||||
|
"send_command_for_instance",
|
||||||
|
fake_send_command_for_instance,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _unused_send_fn(*_args, **_kwargs):
|
||||||
|
raise AssertionError("send_fn should not be used in HTTP mode")
|
||||||
|
|
||||||
|
result = await unity_transport.send_with_unity_instance(
|
||||||
|
_unused_send_fn, None, "ping", {}, user_id="user-1"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["data"] == {"via": "http-remote"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stdio_smoke(monkeypatch):
|
||||||
|
"""Stdio transport should call the legacy send fn with instance_id."""
|
||||||
|
monkeypatch.setattr(config, "transport_mode", "stdio")
|
||||||
|
monkeypatch.setattr(config, "http_remote_hosted", False)
|
||||||
|
|
||||||
|
async def fake_send_fn(command_type, params, *, instance_id=None, **_kwargs):
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {"via": "stdio", "command": command_type, "instance": instance_id, "params": params},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await unity_transport.send_with_unity_instance(
|
||||||
|
fake_send_fn, "Project@abcd1234", "ping", {"x": 1}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["data"]["via"] == "stdio"
|
||||||
|
assert result["data"]["instance"] == "Project@abcd1234"
|
||||||
|
|
@ -65,21 +65,6 @@ def mock_instances_response():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_sessions_response():
|
|
||||||
"""Mock plugin sessions response (legacy format)."""
|
|
||||||
return {
|
|
||||||
"sessions": {
|
|
||||||
"test-session-123": {
|
|
||||||
"project": "TestProject",
|
|
||||||
"hash": "abc123def456",
|
|
||||||
"unity_version": "2022.3.10f1",
|
|
||||||
"connected_at": "2024-01-01T00:00:00Z",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Config Tests
|
# Config Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -246,23 +231,6 @@ class TestConnection:
|
||||||
with pytest.raises(UnityConnectionError):
|
with pytest.raises(UnityConnectionError):
|
||||||
await send_command("test_command", {})
|
await send_command("test_command", {})
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_instances_from_sessions(self, mock_sessions_response):
|
|
||||||
"""Test listing instances from /plugin/sessions endpoint."""
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.status_code = 200
|
|
||||||
mock_response.json.return_value = mock_sessions_response
|
|
||||||
|
|
||||||
with patch("httpx.AsyncClient") as mock_client:
|
|
||||||
# First call (api/instances) returns 404, second (plugin/sessions) succeeds
|
|
||||||
mock_get = AsyncMock(return_value=mock_response)
|
|
||||||
mock_client.return_value.__aenter__.return_value.get = mock_get
|
|
||||||
|
|
||||||
result = await list_unity_instances()
|
|
||||||
assert result["success"] is True
|
|
||||||
assert len(result["instances"]) == 1
|
|
||||||
assert result["instances"][0]["project"] == "TestProject"
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# CLI Command Tests
|
# CLI Command Tests
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ import uuid
|
||||||
|
|
||||||
from transport.unity_instance_middleware import UnityInstanceMiddleware, get_unity_instance_middleware, set_unity_instance_middleware
|
from transport.unity_instance_middleware import UnityInstanceMiddleware, get_unity_instance_middleware, set_unity_instance_middleware
|
||||||
from transport.plugin_registry import PluginRegistry, PluginSession
|
from transport.plugin_registry import PluginRegistry, PluginSession
|
||||||
from transport.plugin_hub import PluginHub, NoUnitySessionError, PluginDisconnectedError
|
from transport.plugin_hub import PluginHub, NoUnitySessionError, InstanceSelectionRequiredError, PluginDisconnectedError
|
||||||
from transport.models import (
|
from transport.models import (
|
||||||
RegisterMessage,
|
RegisterMessage,
|
||||||
RegisterToolsMessage,
|
RegisterToolsMessage,
|
||||||
|
|
@ -775,7 +775,7 @@ class TestSessionResolution:
|
||||||
unity_version="2023.2"
|
unity_version="2023.2"
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="Multiple Unity instances"):
|
with pytest.raises(InstanceSelectionRequiredError, match="Multiple Unity instances"):
|
||||||
await PluginHub._resolve_session_id(None)
|
await PluginHub._resolve_session_id(None)
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,261 @@
|
||||||
|
# Remote Server API Key Authentication
|
||||||
|
|
||||||
|
When running the MCP for Unity server as a shared remote service, API key authentication ensures that only authorized users can access the server and that each user's Unity sessions are isolated from one another.
|
||||||
|
|
||||||
|
This guide covers how to configure, deploy, and use the feature.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
### External Auth Service
|
||||||
|
|
||||||
|
You need an external HTTP endpoint that validates API keys. The server delegates all key validation to this endpoint rather than managing keys itself.
|
||||||
|
|
||||||
|
The endpoint must:
|
||||||
|
|
||||||
|
- Accept `POST` requests with a JSON body: `{"api_key": "<key>"}`
|
||||||
|
- Return a JSON response indicating validity and the associated user identity
|
||||||
|
- Be reachable from the MCP server over the network
|
||||||
|
|
||||||
|
See [Validation Contract](#validation-contract) for the full request/response specification.
|
||||||
|
|
||||||
|
### Transport Mode
|
||||||
|
|
||||||
|
API key authentication is only available when running with HTTP transport (`--transport http`). It has no effect in stdio mode.
|
||||||
|
|
||||||
|
## Server Configuration
|
||||||
|
|
||||||
|
### CLI Arguments
|
||||||
|
|
||||||
|
| Argument | Environment Variable | Default | Description |
|
||||||
|
| -------- | -------------------- | ------- | ----------- |
|
||||||
|
| `--http-remote-hosted` | `UNITY_MCP_HTTP_REMOTE_HOSTED` | `false` | Enable remote-hosted mode. Requires API key auth. |
|
||||||
|
| `--api-key-validation-url URL` | `UNITY_MCP_API_KEY_VALIDATION_URL` | None | External endpoint to validate API keys (required). |
|
||||||
|
| `--api-key-login-url URL` | `UNITY_MCP_API_KEY_LOGIN_URL` | None | URL where users can obtain or manage API keys. |
|
||||||
|
| `--api-key-cache-ttl SECONDS` | `UNITY_MCP_API_KEY_CACHE_TTL` | `300` | How long validated keys are cached (seconds). |
|
||||||
|
| `--api-key-service-token-header HEADER` | `UNITY_MCP_API_KEY_SERVICE_TOKEN_HEADER` | None | Header name for server-to-auth-service authentication. |
|
||||||
|
| `--api-key-service-token TOKEN` | `UNITY_MCP_API_KEY_SERVICE_TOKEN` | None | Token value sent to the auth service for server authentication. |
|
||||||
|
|
||||||
|
Environment variables take effect when the corresponding CLI argument is not provided. For boolean flags, set the env var to `true`, `1`, or `yes`.
|
||||||
|
|
||||||
|
### Startup Validation
|
||||||
|
|
||||||
|
The server validates its configuration at startup:
|
||||||
|
|
||||||
|
- If `--http-remote-hosted` is set but `--api-key-validation-url` is not provided (and the env var is also unset), the server logs an error and exits with code 1.
|
||||||
|
|
||||||
|
### Example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m src.main \
|
||||||
|
--transport http \
|
||||||
|
--http-host 0.0.0.0 \
|
||||||
|
--http-port 8080 \
|
||||||
|
--http-remote-hosted \
|
||||||
|
--api-key-validation-url https://auth.example.com/api/validate-key \
|
||||||
|
--api-key-login-url https://app.example.com/api-keys \
|
||||||
|
--api-key-cache-ttl 120
|
||||||
|
```
|
||||||
|
|
||||||
|
Or using environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export UNITY_MCP_TRANSPORT=http
|
||||||
|
export UNITY_MCP_HTTP_HOST=0.0.0.0
|
||||||
|
export UNITY_MCP_HTTP_PORT=8080
|
||||||
|
export UNITY_MCP_HTTP_REMOTE_HOSTED=true
|
||||||
|
export UNITY_MCP_API_KEY_VALIDATION_URL=https://auth.example.com/api/validate-key
|
||||||
|
export UNITY_MCP_API_KEY_LOGIN_URL=https://app.example.com/api-keys
|
||||||
|
|
||||||
|
python -m src.main
|
||||||
|
```
|
||||||
|
|
||||||
|
### Service Token (Optional)
|
||||||
|
|
||||||
|
If your auth service requires the MCP server to authenticate itself (server-to-server auth), configure a service token:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--api-key-service-token-header X-Service-Token \
|
||||||
|
--api-key-service-token "your-server-secret"
|
||||||
|
```
|
||||||
|
|
||||||
|
This adds the specified header to every validation request sent to the auth endpoint.
|
||||||
|
|
||||||
|
We strongly recommend using this feature because it ensures that the entity requesting validation is the MCP server itself, not an imposter.
|
||||||
|
|
||||||
|
## Unity Plugin Setup
|
||||||
|
|
||||||
|
When connecting to a remote-hosted server, Unity users need to provide their API key:
|
||||||
|
|
||||||
|
1. Open the MCP for Unity window in the Unity Editor.
|
||||||
|
2. Select HTTP Remote as the connection mode.
|
||||||
|
3. Enter the API key in the API Key field. The key is stored in `EditorPrefs` (per-machine, not source-controlled).
|
||||||
|
4. Click **Get API Key** to open the login URL in a browser if you need a new key. This fetches the URL from the server's `/api/auth/login-url` endpoint.
|
||||||
|
|
||||||
|
The API key is a one-time entry per machine. It persists across Unity sessions until explicitly cleared.
|
||||||
|
|
||||||
|
## MCP Client Configuration
|
||||||
|
|
||||||
|
When an API key is configured, the Unity plugin's MCP client configurators automatically include the `X-API-Key` header in generated configuration files.
|
||||||
|
|
||||||
|
Example generated config for **Cursor** (`~/.cursor/mcp.json`):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"mcp-for-unity": {
|
||||||
|
"url": "http://remote-server:8080/mcp",
|
||||||
|
"headers": {
|
||||||
|
"X-API-Key": "<your-api-key>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Example for **Claude Code** (CLI):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
claude mcp add --transport http mcp-for-unity http://remote-server:8080/mcp \
|
||||||
|
--header "X-API-Key: <your-api-key>"
|
||||||
|
```
|
||||||
|
|
||||||
|
Similar header injection works for VS Code, Windsurf, Cline, and other supported MCP clients.
|
||||||
|
|
||||||
|
## Behaviour Changes in Remote-Hosted Mode
|
||||||
|
|
||||||
|
Enabling `--http-remote-hosted` changes several server behaviours compared to the default local mode:
|
||||||
|
|
||||||
|
### Authentication Enforcement
|
||||||
|
|
||||||
|
All MCP tool and resource calls require a valid API key. The `X-API-Key` header must be present on every HTTP request to the `/mcp` endpoint. If the key is missing or invalid, the middleware raises a `RuntimeError` that surfaces as an MCP error response.
|
||||||
|
|
||||||
|
### WebSocket Auth Gate
|
||||||
|
|
||||||
|
Unity plugins connecting via WebSocket (`/hub/plugin`) are validated during the handshake:
|
||||||
|
|
||||||
|
| Scenario | WebSocket Close Code | Reason |
|
||||||
|
| -------- | -------------------- | ------ |
|
||||||
|
| No API key header | `4401` | API key required |
|
||||||
|
| Invalid API key | `4403` | Invalid API key |
|
||||||
|
| Auth service unavailable | `1013` | Try again later |
|
||||||
|
| Valid API key | Connection accepted | user_id stored in connection state |
|
||||||
|
|
||||||
|
### Session Isolation
|
||||||
|
|
||||||
|
Each user can only see and interact with their own Unity instances. When User A calls `set_active_instance` or lists instances, they only see Unity editors that connected with User A's API key. User B's sessions are invisible to User A.
|
||||||
|
|
||||||
|
### Auto-Select Disabled
|
||||||
|
|
||||||
|
In local mode, the server automatically selects the sole connected Unity instance. In remote-hosted mode, this auto-selection is disabled. Users must explicitly call `set_active_instance` with a `Name@hash` from the `mcpforunity://instances` resource.
|
||||||
|
|
||||||
|
### CLI Routes Disabled
|
||||||
|
|
||||||
|
The following REST endpoints are disabled in remote-hosted mode to prevent unauthenticated access:
|
||||||
|
|
||||||
|
- `POST /api/command`
|
||||||
|
- `GET /api/instances`
|
||||||
|
- `GET /api/custom-tools`
|
||||||
|
|
||||||
|
### Endpoints Always Available
|
||||||
|
|
||||||
|
These endpoints remain accessible regardless of auth:
|
||||||
|
|
||||||
|
| Endpoint | Method | Purpose |
|
||||||
|
| -------- | ------ | ------- |
|
||||||
|
| `/health` | GET | Health check for load balancers and monitoring |
|
||||||
|
| `/api/auth/login-url` | GET | Returns the login URL for API key management |
|
||||||
|
|
||||||
|
## Validation Contract
|
||||||
|
|
||||||
|
### Request
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST <api-key-validation-url>
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"api_key": "<the-api-key>"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
If a service token is configured, an additional header is sent:
|
||||||
|
|
||||||
|
```http
|
||||||
|
<service-token-header>: <service-token-value>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Response (Valid Key)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"valid": true,
|
||||||
|
"user_id": "user-abc-123",
|
||||||
|
"metadata": {}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `valid` (bool, required): Must be `true`.
|
||||||
|
- `user_id` (string, required): Stable identifier for the user. Used for session isolation.
|
||||||
|
- `metadata` (object, optional): Arbitrary metadata stored alongside the validation result.
|
||||||
|
|
||||||
|
### Response (Invalid Key)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"valid": false,
|
||||||
|
"error": "API key expired"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `valid` (bool, required): Must be `false`.
|
||||||
|
- `error` (string, optional): Human-readable reason.
|
||||||
|
|
||||||
|
### Response (HTTP 401)
|
||||||
|
|
||||||
|
A `401` status code is also treated as an invalid key (no body parsing required).
|
||||||
|
|
||||||
|
### Timeouts and Retries
|
||||||
|
|
||||||
|
- Request timeout: 5 seconds
|
||||||
|
- Retries: 1 (with 100ms backoff)
|
||||||
|
- Failure mode: deny by default (treated as invalid on any error)
|
||||||
|
|
||||||
|
Transient failures (5xx, timeouts, network errors) are **not cached**, so subsequent requests will retry the auth service.
|
||||||
|
|
||||||
|
## Error Reference
|
||||||
|
|
||||||
|
| Context | Condition | Response |
|
||||||
|
| ------- | --------- | -------- |
|
||||||
|
| MCP tool/resource | Missing API key (remote-hosted) | `RuntimeError` → MCP `isError: true` |
|
||||||
|
| MCP tool/resource | Invalid API key | `RuntimeError` → MCP `isError: true` |
|
||||||
|
| WebSocket connect | Missing API key | Close `4401` "API key required" |
|
||||||
|
| WebSocket connect | Invalid API key | Close `4403` "Invalid API key" |
|
||||||
|
| WebSocket connect | Auth service down | Close `1013` "Try again later" |
|
||||||
|
| `/api/auth/login-url` | Login URL not configured | HTTP `404` with admin guidance message |
|
||||||
|
| Server startup | Remote-hosted without validation URL | `SystemExit(1)` |
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### "API key authentication required" error on every tool call
|
||||||
|
|
||||||
|
The server is in remote-hosted mode but no API key is being sent. Ensure the MCP client configuration includes the `X-API-Key` header, or set it in the Unity plugin's connection settings.
|
||||||
|
|
||||||
|
### Server exits immediately with code 1
|
||||||
|
|
||||||
|
The `--http-remote-hosted` flag requires `--api-key-validation-url`. Provide the URL via CLI argument or `UNITY_MCP_API_KEY_VALIDATION_URL` environment variable.
|
||||||
|
|
||||||
|
### WebSocket connection closes with 4401
|
||||||
|
|
||||||
|
The Unity plugin is not sending an API key. Enter the key in the MCP for Unity window's connection settings.
|
||||||
|
|
||||||
|
### WebSocket connection closes with 1013
|
||||||
|
|
||||||
|
The external auth service is unreachable. Check network connectivity between the MCP server and the validation URL. The Unity plugin can retry the connection.
|
||||||
|
|
||||||
|
### User cannot see their Unity instance
|
||||||
|
|
||||||
|
Session isolation is active. The Unity editor and the MCP client must use API keys that resolve to the same `user_id`. Verify that the Unity plugin's WebSocket connection and the MCP client's HTTP requests use the same API key.
|
||||||
|
|
||||||
|
### Stale auth after key rotation
|
||||||
|
|
||||||
|
Validated keys are cached for `--api-key-cache-ttl` seconds (default: 300). After rotating or revoking a key, there is a delay equal to the TTL before the old key stops working. Lower the TTL for faster revocation at the cost of more frequent validation requests.
|
||||||
|
|
@ -0,0 +1,363 @@
|
||||||
|
# Remote Server Auth: Architecture
|
||||||
|
|
||||||
|
This document describes the internal design of the API key authentication system used when the MCP for Unity server runs in remote-hosted mode. It is intended for contributors and maintainers.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
```
|
||||||
|
MCP Client MCP Server External Auth
|
||||||
|
(Cursor, etc.) (Python) Service
|
||||||
|
| | |
|
||||||
|
| X-API-Key: abc123 | |
|
||||||
|
| POST /mcp (tool call) | |
|
||||||
|
|-------------------------->| |
|
||||||
|
| | |
|
||||||
|
| UnityInstanceMiddleware.on_call_tool |
|
||||||
|
| | |
|
||||||
|
| _resolve_user_id() |
|
||||||
|
| | |
|
||||||
|
| | POST /validate |
|
||||||
|
| | {"api_key": "abc123"} |
|
||||||
|
| |------------------------------>|
|
||||||
|
| | |
|
||||||
|
| | {"valid":true, |
|
||||||
|
| | "user_id":"user-42"} |
|
||||||
|
| |<------------------------------|
|
||||||
|
| | |
|
||||||
|
| Cache result (TTL) |
|
||||||
|
| | |
|
||||||
|
| ctx.set_state("user_id", "user-42") |
|
||||||
|
| ctx.set_state("unity_instance", "Proj@hash") |
|
||||||
|
| | |
|
||||||
|
| PluginHub.send_command_for_instance |
|
||||||
|
| (user_id scoped session lookup) |
|
||||||
|
| | |
|
||||||
|
| Tool result | |
|
||||||
|
|<--------------------------| |
|
||||||
|
|
||||||
|
|
||||||
|
Unity Plugin MCP Server External Auth
|
||||||
|
(C# WebSocket) (Python) Service
|
||||||
|
| | |
|
||||||
|
| WS /hub/plugin | |
|
||||||
|
| X-API-Key: abc123 | |
|
||||||
|
|-------------------------->| |
|
||||||
|
| | |
|
||||||
|
| PluginHub.on_connect |
|
||||||
|
| | POST /validate |
|
||||||
|
| |------------------------------>|
|
||||||
|
| | {"valid":true, ...} |
|
||||||
|
| |<------------------------------|
|
||||||
|
| | |
|
||||||
|
| accept() | |
|
||||||
|
| websocket.state.user_id = "user-42" |
|
||||||
|
|<--------------------------| |
|
||||||
|
| | |
|
||||||
|
| {"type":"register", ...} | |
|
||||||
|
|-------------------------->| |
|
||||||
|
| | |
|
||||||
|
| PluginRegistry.register( |
|
||||||
|
| ..., user_id="user-42") |
|
||||||
|
| _user_hash_to_session[("user-42","hash")] = sid |
|
||||||
|
| | |
|
||||||
|
| {"type":"registered"} | |
|
||||||
|
|<--------------------------| |
|
||||||
|
```
|
||||||
|
|
||||||
|
## Components
|
||||||
|
|
||||||
|
### ApiKeyService
|
||||||
|
|
||||||
|
**File:** `Server/src/services/api_key_service.py`
|
||||||
|
|
||||||
|
Singleton service that validates API keys against an external HTTP endpoint.
|
||||||
|
|
||||||
|
- **Singleton access:** `ApiKeyService.get_instance()` / `ApiKeyService.is_initialized()`
|
||||||
|
- **Initialization:** Constructed in `create_mcp_server()` when `config.http_remote_hosted` and `config.api_key_validation_url` are both set.
|
||||||
|
- **Validation:** `async validate(api_key) -> ValidationResult`
|
||||||
|
- **Caching:** In-memory dict keyed by raw API key. Entries store `(valid, user_id, metadata, expires_at)`.
|
||||||
|
- **Retry:** 1 retry with 100ms backoff on timeouts and connection errors.
|
||||||
|
- **Fail-closed:** Any unrecoverable error returns `ValidationResult(valid=False)`.
|
||||||
|
|
||||||
|
### PluginHub (WebSocket Auth Gate)
|
||||||
|
|
||||||
|
**File:** `Server/src/transport/plugin_hub.py`
|
||||||
|
|
||||||
|
The `on_connect` method validates the API key from the WebSocket handshake headers before accepting the connection.
|
||||||
|
|
||||||
|
- Reads `X-API-Key` from `websocket.headers`
|
||||||
|
- Validates via `ApiKeyService.validate()`
|
||||||
|
- Stores `user_id` and `api_key_metadata` on `websocket.state` for use during registration
|
||||||
|
- Rejects with close codes: `4401` (missing), `4403` (invalid), `1013` (service unavailable)
|
||||||
|
|
||||||
|
The `_handle_register` method reads `websocket.state.user_id` and passes it to `PluginRegistry.register()`.
|
||||||
|
|
||||||
|
The `get_sessions(user_id=None)` and `_resolve_session_id(unity_instance, user_id=None)` methods accept an optional `user_id` to scope session queries in remote-hosted mode.
|
||||||
|
|
||||||
|
### PluginRegistry (Dual-Index Session Storage)
|
||||||
|
|
||||||
|
**File:** `Server/src/transport/plugin_registry.py`
|
||||||
|
|
||||||
|
In-memory registry of connected Unity plugin sessions. Maintains two parallel index maps:
|
||||||
|
|
||||||
|
| Index | Key | Used In |
|
||||||
|
|-------|-----|---------|
|
||||||
|
| `_hash_to_session` | `project_hash -> session_id` | Local mode |
|
||||||
|
| `_user_hash_to_session` | `(user_id, project_hash) -> session_id` | Remote-hosted mode |
|
||||||
|
|
||||||
|
Both indexes are updated during `register()` and cleaned up during `unregister()`.
|
||||||
|
|
||||||
|
Key methods:
|
||||||
|
|
||||||
|
- `register(session_id, project_name, project_hash, unity_version, user_id=None)` - Registers a session and updates the appropriate index. If an existing session claims the same key, it is evicted.
|
||||||
|
- `get_session_id_by_hash(project_hash)` - Local-mode lookup.
|
||||||
|
- `get_session_id_by_hash(project_hash, user_id)` - Remote-mode lookup.
|
||||||
|
- `list_sessions(user_id=None)` - Returns sessions filtered by user. Raises `ValueError` if `user_id` is `None` while `config.http_remote_hosted` is `True`, preventing accidental cross-user leaks.
|
||||||
|
|
||||||
|
### UnityInstanceMiddleware
|
||||||
|
|
||||||
|
**File:** `Server/src/transport/unity_instance_middleware.py`
|
||||||
|
|
||||||
|
FastMCP middleware that intercepts all tool and resource calls to inject the active Unity instance and user identity into the request-scoped context.
|
||||||
|
|
||||||
|
Entry points:
|
||||||
|
|
||||||
|
- `on_call_tool(context, call_next)` - Intercepts tool calls.
|
||||||
|
- `on_read_resource(context, call_next)` - Intercepts resource reads.
|
||||||
|
|
||||||
|
Both delegate to `_inject_unity_instance(context)`, which:
|
||||||
|
|
||||||
|
1. Calls `_resolve_user_id()` to extract the user identity from the HTTP request.
|
||||||
|
2. If remote-hosted mode is active and no `user_id` is resolved, raises `RuntimeError` (surfaces as MCP error).
|
||||||
|
3. Sets `ctx.set_state("user_id", user_id)`.
|
||||||
|
4. Looks up or auto-selects the active Unity instance.
|
||||||
|
5. Sets `ctx.set_state("unity_instance", active_instance)`.
|
||||||
|
|
||||||
|
### _resolve_user_id_from_request
|
||||||
|
|
||||||
|
**File:** `Server/src/transport/unity_transport.py`
|
||||||
|
|
||||||
|
Extracts the `user_id` from the current HTTP request's `X-API-Key` header.
|
||||||
|
|
||||||
|
```
|
||||||
|
_resolve_user_id_from_request()
|
||||||
|
-> if not config.http_remote_hosted: return None
|
||||||
|
-> if not ApiKeyService.is_initialized(): return None
|
||||||
|
-> get_http_headers() from FastMCP dependencies
|
||||||
|
-> extract "x-api-key" header
|
||||||
|
-> ApiKeyService.validate(api_key)
|
||||||
|
-> return result.user_id if valid, else None
|
||||||
|
```
|
||||||
|
|
||||||
|
The middleware calls this indirectly through `_resolve_user_id()`, which adds an early return when not in remote-hosted mode (avoiding the import of FastMCP internals in local mode).
|
||||||
|
|
||||||
|
## Request Lifecycle
|
||||||
|
|
||||||
|
A complete authenticated MCP tool call follows this path:
|
||||||
|
|
||||||
|
1. **HTTP request arrives** at `/mcp` with `X-API-Key: <key>` header.
|
||||||
|
|
||||||
|
2. **FastMCP dispatches** the MCP tool call through its middleware chain.
|
||||||
|
|
||||||
|
3. **`UnityInstanceMiddleware.on_call_tool`** is invoked.
|
||||||
|
|
||||||
|
4. **`_inject_unity_instance`** runs:
|
||||||
|
- Calls `_resolve_user_id()`, which calls `_resolve_user_id_from_request()`.
|
||||||
|
- The request function imports `get_http_headers` from FastMCP and reads the `x-api-key` header.
|
||||||
|
- `ApiKeyService.validate()` checks the cache or calls the external auth endpoint.
|
||||||
|
- If valid, `user_id` is returned. If invalid or missing, `None` is returned.
|
||||||
|
- In remote-hosted mode, `None` causes a `RuntimeError`.
|
||||||
|
|
||||||
|
5. **`user_id` stored in context** via `ctx.set_state("user_id", user_id)`.
|
||||||
|
|
||||||
|
6. **Session key derived** by `get_session_key(ctx)`:
|
||||||
|
- Priority: `client_id` (if available) > `user:{user_id}` > `"global"`.
|
||||||
|
- The `user:{user_id}` fallback ensures session isolation when MCP transports don't provide stable client IDs.
|
||||||
|
|
||||||
|
7. **Active Unity instance looked up** from `_active_by_key` dict using the session key. If none is set, `_maybe_autoselect_instance` is called (but returns `None` in remote-hosted mode).
|
||||||
|
|
||||||
|
8. **Instance injected** via `ctx.set_state("unity_instance", active_instance)`.
|
||||||
|
|
||||||
|
9. **Tool executes**, reading the instance from `ctx.get_state("unity_instance")`.
|
||||||
|
|
||||||
|
10. **Command routed** through `PluginHub.send_command_for_instance(unity_instance, ..., user_id=user_id)`, which resolves the session using `PluginRegistry.get_session_id_by_hash(project_hash, user_id)`.
|
||||||
|
|
||||||
|
## WebSocket Auth Flow
|
||||||
|
|
||||||
|
When a Unity plugin connects via WebSocket:
|
||||||
|
|
||||||
|
```
|
||||||
|
Plugin -> WS /hub/plugin (with X-API-Key header)
|
||||||
|
|
|
||||||
|
v
|
||||||
|
PluginHub.on_connect()
|
||||||
|
|
|
||||||
|
+-- config.http_remote_hosted && ApiKeyService.is_initialized()?
|
||||||
|
| |
|
||||||
|
| +-- No -> accept() (local mode, no auth needed)
|
||||||
|
| |
|
||||||
|
| +-- Yes -> read X-API-Key from headers
|
||||||
|
| |
|
||||||
|
| +-- No key -> close(4401, "API key required")
|
||||||
|
| |
|
||||||
|
| +-- ApiKeyService.validate(key)
|
||||||
|
| |
|
||||||
|
| +-- valid=True -> websocket.state.user_id = user_id
|
||||||
|
| | accept()
|
||||||
|
| |
|
||||||
|
| +-- valid=False, "unavailable" in error
|
||||||
|
| | -> close(1013, "Try again later")
|
||||||
|
| |
|
||||||
|
| +-- valid=False -> close(4403, "Invalid API key")
|
||||||
|
```
|
||||||
|
|
||||||
|
After acceptance, when the plugin sends a `register` message, `_handle_register` reads `websocket.state.user_id` and passes it to `PluginRegistry.register()`.
|
||||||
|
|
||||||
|
## Session Registry Design
|
||||||
|
|
||||||
|
### Local Mode
|
||||||
|
|
||||||
|
```
|
||||||
|
project_hash -> session_id
|
||||||
|
"abc123" -> "uuid-1"
|
||||||
|
"def456" -> "uuid-2"
|
||||||
|
```
|
||||||
|
|
||||||
|
A single `_hash_to_session` dict. Any user can see any session. `list_sessions(user_id=None)` returns all sessions.
|
||||||
|
|
||||||
|
### Remote-Hosted Mode
|
||||||
|
|
||||||
|
```
|
||||||
|
(user_id, project_hash) -> session_id
|
||||||
|
("user-A", "abc123") -> "uuid-1"
|
||||||
|
("user-B", "abc123") -> "uuid-3" (same project, different user)
|
||||||
|
("user-A", "def456") -> "uuid-2"
|
||||||
|
```
|
||||||
|
|
||||||
|
A separate `_user_hash_to_session` dict with composite keys. Two users working on cloned repos (same `project_hash`) get independent sessions.
|
||||||
|
|
||||||
|
### Reconnect Handling
|
||||||
|
|
||||||
|
When a Unity editor reconnects (e.g., after domain reload), `register()` detects the existing mapping for the same key and evicts the old session before inserting the new one. This ensures the latest WebSocket connection always wins.
|
||||||
|
|
||||||
|
### list_sessions Guard
|
||||||
|
|
||||||
|
`list_sessions(user_id=None)` raises `ValueError` when `config.http_remote_hosted` is `True`. This prevents code paths from accidentally listing all users' sessions. Every call site in remote-hosted mode must pass an explicit `user_id`.
|
||||||
|
|
||||||
|
## Caching Strategy
|
||||||
|
|
||||||
|
`ApiKeyService` maintains an in-memory cache:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# api_key -> (valid, user_id, metadata, expires_at)
|
||||||
|
_cache: dict[str, tuple[bool, str | None, dict | None, float]]
|
||||||
|
```
|
||||||
|
|
||||||
|
### What Gets Cached
|
||||||
|
|
||||||
|
| Response | Cached? | Rationale |
|
||||||
|
|----------|---------|-----------|
|
||||||
|
| 200 + `valid: true` | Yes | Definitive valid result |
|
||||||
|
| 200 + `valid: false` | Yes | Definitive invalid result |
|
||||||
|
| 401 status | Yes | Definitive rejection |
|
||||||
|
| 5xx status | No | Transient; retry on next request |
|
||||||
|
| Timeout | No | Transient; retry on next request |
|
||||||
|
| Connection error | No | Transient; retry on next request |
|
||||||
|
| Unexpected exception | No | Transient; retry on next request |
|
||||||
|
|
||||||
|
Non-cacheable results use `ValidationResult(cacheable=False)`.
|
||||||
|
|
||||||
|
### Cache Lifecycle
|
||||||
|
|
||||||
|
- **TTL:** Configurable via `--api-key-cache-ttl` (default: 300 seconds).
|
||||||
|
- **Expiry:** Checked on read. Expired entries are deleted and re-validated.
|
||||||
|
- **Invalidation:** `invalidate_cache(api_key)` removes a single key. `clear_cache()` removes all.
|
||||||
|
- **Concurrency:** Protected by `asyncio.Lock`.
|
||||||
|
|
||||||
|
### Revocation Latency
|
||||||
|
|
||||||
|
A revoked key continues to work for up to `cache_ttl` seconds. Lower the TTL for faster revocation at the cost of more validation requests.
|
||||||
|
|
||||||
|
## Fail-Closed Behaviour
|
||||||
|
|
||||||
|
The system fails closed at every boundary:
|
||||||
|
|
||||||
|
| Component | Failure | Behaviour |
|
||||||
|
|-----------|---------|-----------|
|
||||||
|
| `ApiKeyService._validate_external` | Timeout after retries | `valid=False, cacheable=False` |
|
||||||
|
| `ApiKeyService._validate_external` | Connection error after retries | `valid=False, cacheable=False` |
|
||||||
|
| `ApiKeyService._validate_external` | 5xx status | `valid=False, cacheable=False` |
|
||||||
|
| `ApiKeyService._validate_external` | Unexpected exception | `valid=False, cacheable=False` |
|
||||||
|
| `PluginHub.on_connect` | Auth service unavailable | Close `1013` (retry hint) |
|
||||||
|
| `UnityInstanceMiddleware._inject_unity_instance` | No user_id in remote-hosted mode | `RuntimeError` |
|
||||||
|
|
||||||
|
API keys are never logged in full. Keys longer than 8 characters are redacted to `xxxx...yyyy` in log messages.
|
||||||
|
|
||||||
|
## Session Key Derivation
|
||||||
|
|
||||||
|
`UnityInstanceMiddleware.get_session_key(ctx)` determines which dict key to use for storing/retrieving the active Unity instance per session:
|
||||||
|
|
||||||
|
```
|
||||||
|
1. client_id (string, non-empty) -> return client_id
|
||||||
|
2. ctx.get_state("user_id") -> return "user:{user_id}"
|
||||||
|
3. fallback -> return "global"
|
||||||
|
```
|
||||||
|
|
||||||
|
- **`client_id`:** Stable per MCP client connection. Preferred when available.
|
||||||
|
- **`user:{user_id}`:** Used in remote-hosted mode when the MCP transport doesn't provide a stable client ID. Ensures different users don't share instance selections.
|
||||||
|
- **`"global"`:** Local-dev fallback for single-user scenarios. Unreachable in remote-hosted mode because the auth enforcement raises `RuntimeError` before this point if no `user_id` is available.
|
||||||
|
|
||||||
|
## Disabled Features in Remote-Hosted Mode
|
||||||
|
|
||||||
|
| Feature | Local Mode | Remote-Hosted Mode | Reason |
|
||||||
|
|---------|-----------|-------------------|--------|
|
||||||
|
| Auto-select sole instance | Enabled | Disabled | Implicit behaviour is dangerous with multiple users |
|
||||||
|
| CLI REST routes | Enabled | Disabled | No auth layer on these endpoints |
|
||||||
|
| `list_sessions(user_id=None)` | Returns all | Raises `ValueError` | Prevents accidental cross-user session leaks |
|
||||||
|
|
||||||
|
## Configuration Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
CLI args / env vars
|
||||||
|
|
|
||||||
|
v
|
||||||
|
main.py: parser.parse_args()
|
||||||
|
|
|
||||||
|
+-- config.http_remote_hosted = args or env
|
||||||
|
+-- config.api_key_validation_url = args or env
|
||||||
|
+-- config.api_key_login_url = args or env
|
||||||
|
+-- config.api_key_cache_ttl = args or env (float)
|
||||||
|
+-- config.api_key_service_token_header = args or env
|
||||||
|
+-- config.api_key_service_token = args or env
|
||||||
|
|
|
||||||
|
+-- Validate: remote-hosted requires validation URL
|
||||||
|
| (exits with code 1 if missing)
|
||||||
|
|
|
||||||
|
v
|
||||||
|
create_mcp_server()
|
||||||
|
|
|
||||||
|
+-- get_unity_instance_middleware() -> registers middleware
|
||||||
|
|
|
||||||
|
+-- if remote-hosted + validation URL:
|
||||||
|
| ApiKeyService(
|
||||||
|
| validation_url, cache_ttl,
|
||||||
|
| service_token_header, service_token
|
||||||
|
| )
|
||||||
|
|
|
||||||
|
+-- WebSocketRoute("/hub/plugin", PluginHub)
|
||||||
|
|
|
||||||
|
+-- if not remote-hosted:
|
||||||
|
register CLI routes (/api/command, /api/instances, /api/custom-tools)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key Files
|
||||||
|
|
||||||
|
| File | Role |
|
||||||
|
|------|------|
|
||||||
|
| `Server/src/core/config.py` | `ServerConfig` dataclass with auth fields |
|
||||||
|
| `Server/src/main.py` | CLI argument parsing, startup validation, service initialization |
|
||||||
|
| `Server/src/services/api_key_service.py` | API key validation singleton with caching and retry |
|
||||||
|
| `Server/src/transport/plugin_hub.py` | WebSocket auth gate, user-scoped session queries |
|
||||||
|
| `Server/src/transport/plugin_registry.py` | Dual-index session storage (local + user-scoped) |
|
||||||
|
| `Server/src/transport/unity_instance_middleware.py` | Per-request user_id and instance injection |
|
||||||
|
| `Server/src/transport/unity_transport.py` | `_resolve_user_id_from_request` helper |
|
||||||
Loading…
Reference in New Issue