Standardize how we define MCP tools (#292)

* refactor: migrate command routing to use CommandRegistry lookup instead of switch statement

* style: improve code formatting and indentation consistency

* refactor: clean up imports and type hints across tool modules

* Revert "feat: Implement Asset Store Compliance for Unity MCP Bridge"

This reverts commit 2fca7fc3da.

* Revert "feat(asset-store): implement post-installation prompt system for Asset Store compliance"

This reverts commit ab25a71bc5.

* chore: upgrade mcp[cli] dependency from 1.4.1 to 1.15.0

* style: fix formatting and whitespace in Python server files

* Remove description, probably a Python version change

* feat: add type hints and parameter descriptions to Unity MCP tools

* docs: improve shader management tool parameter descriptions and types

* refactor: add type annotations and improve documentation for script management tools

* refactor: improve type annotations and documentation in manage_scene tool

* refactor: add type annotations and improve parameter descriptions across MCP tools

* feat: add explicit name parameters to all MCP tool decorators

* refactor: remove unused Unity connection instance in manage_asset_tools

* chore: update type hints in manage_editor function parameters for better clarity

* feat: make name and path parameters optional for scene management operations

* refactor: remove unused get_unity_connection import from manage_asset.py

* chore: rename Operation parameter annotation to Operations for consistency

* feat: add logging to MCP clients for tool actions across MCP server components

* chore: add FastMCP type hint to register_all_tools parameter

* style: reformat docstring in apply_text_edits tool to use multiline string syntax

* refactor: update type hints from Dict/List/Tuple/Optional to modern Python syntax

* refactor: clean up imports and add type annotations to script editing tools

* refactor: update type hints to use | None syntax for optional parameters

* Minor fixes

* docs: improve tool descriptions with clearer action explanations

* refactor: remove legacy update action migration code from manage_script.py

* style: replace em dashes with regular hyphens in tool descriptions [skip ci]

* refactor: convert manage_script_capabilities docstring to multiline format [skip ci]
main
Marcus Sanatan 2025-09-27 13:53:10 -04:00 committed by GitHub
parent af4ddf1dd6
commit 5acf10769e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 1191 additions and 1209 deletions

View File

@ -1040,27 +1040,7 @@ namespace MCPForUnity.Editor
// Use JObject for parameters as the new handlers likely expect this // Use JObject for parameters as the new handlers likely expect this
JObject paramsObject = command.@params ?? new JObject(); JObject paramsObject = command.@params ?? new JObject();
object result = CommandRegistry.GetHandler(command.type)(paramsObject);
// Route command based on the new tool structure from the refactor plan
object result = command.type switch
{
// Maps the command type (tool name) to the corresponding handler's static HandleCommand method
// Assumes each handler class has a static method named 'HandleCommand' that takes JObject parameters
"manage_script" => ManageScript.HandleCommand(paramsObject),
// Run scene operations on the main thread to avoid deadlocks/hangs (with diagnostics under debug flag)
"manage_scene" => HandleManageScene(paramsObject)
?? throw new TimeoutException($"manage_scene timed out after {FrameIOTimeoutMs} ms on main thread"),
"manage_editor" => ManageEditor.HandleCommand(paramsObject),
"manage_gameobject" => ManageGameObject.HandleCommand(paramsObject),
"manage_asset" => ManageAsset.HandleCommand(paramsObject),
"manage_shader" => ManageShader.HandleCommand(paramsObject),
"read_console" => ReadConsole.HandleCommand(paramsObject),
"manage_menu_item" => ManageMenuItem.HandleCommand(paramsObject),
"manage_prefabs" => ManagePrefabs.HandleCommand(paramsObject),
_ => throw new ArgumentException(
$"Unknown or unsupported command type: {command.type}"
),
};
// Standard success response format // Standard success response format
var response = new { status = "success", result }; var response = new { status = "success", result };

View File

@ -2,6 +2,7 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using Newtonsoft.Json.Linq; using Newtonsoft.Json.Linq;
using MCPForUnity.Editor.Tools.MenuItems; using MCPForUnity.Editor.Tools.MenuItems;
using MCPForUnity.Editor.Tools.Prefabs;
namespace MCPForUnity.Editor.Tools namespace MCPForUnity.Editor.Tools
{ {
@ -22,6 +23,7 @@ namespace MCPForUnity.Editor.Tools
{ "read_console", ReadConsole.HandleCommand }, { "read_console", ReadConsole.HandleCommand },
{ "manage_menu_item", ManageMenuItem.HandleCommand }, { "manage_menu_item", ManageMenuItem.HandleCommand },
{ "manage_shader", ManageShader.HandleCommand}, { "manage_shader", ManageShader.HandleCommand},
{ "manage_prefabs", ManagePrefabs.HandleCommand},
}; };
/// <summary> /// <summary>

View File

@ -90,7 +90,7 @@ namespace MCPForUnity.Editor.Tools
return false; return false;
} }
var atAssets = string.Equals( var atAssets = string.Equals(
di.FullName.Replace('\\','/'), di.FullName.Replace('\\', '/'),
assets, assets,
StringComparison.OrdinalIgnoreCase StringComparison.OrdinalIgnoreCase
); );
@ -115,7 +115,7 @@ namespace MCPForUnity.Editor.Tools
{ {
return Response.Error("invalid_params", "Parameters cannot be null."); return Response.Error("invalid_params", "Parameters cannot be null.");
} }
// Extract parameters // Extract parameters
string action = @params["action"]?.ToString()?.ToLower(); string action = @params["action"]?.ToString()?.ToLower();
string name = @params["name"]?.ToString(); string name = @params["name"]?.ToString();
@ -207,81 +207,81 @@ namespace MCPForUnity.Editor.Tools
case "delete": case "delete":
return DeleteScript(fullPath, relativePath); return DeleteScript(fullPath, relativePath);
case "apply_text_edits": case "apply_text_edits":
{ {
var textEdits = @params["edits"] as JArray; var textEdits = @params["edits"] as JArray;
string precondition = @params["precondition_sha256"]?.ToString(); string precondition = @params["precondition_sha256"]?.ToString();
// Respect optional options // Respect optional options
string refreshOpt = @params["options"]?["refresh"]?.ToString()?.ToLowerInvariant(); string refreshOpt = @params["options"]?["refresh"]?.ToString()?.ToLowerInvariant();
string validateOpt = @params["options"]?["validate"]?.ToString()?.ToLowerInvariant(); string validateOpt = @params["options"]?["validate"]?.ToString()?.ToLowerInvariant();
return ApplyTextEdits(fullPath, relativePath, name, textEdits, precondition, refreshOpt, validateOpt); return ApplyTextEdits(fullPath, relativePath, name, textEdits, precondition, refreshOpt, validateOpt);
} }
case "validate": case "validate":
{
string level = @params["level"]?.ToString()?.ToLowerInvariant() ?? "standard";
var chosen = level switch
{ {
"basic" => ValidationLevel.Basic, string level = @params["level"]?.ToString()?.ToLowerInvariant() ?? "standard";
"standard" => ValidationLevel.Standard, var chosen = level switch
"strict" => ValidationLevel.Strict, {
"comprehensive" => ValidationLevel.Comprehensive, "basic" => ValidationLevel.Basic,
_ => ValidationLevel.Standard "standard" => ValidationLevel.Standard,
}; "strict" => ValidationLevel.Strict,
string fileText; "comprehensive" => ValidationLevel.Comprehensive,
try { fileText = File.ReadAllText(fullPath); } _ => ValidationLevel.Standard
catch (Exception ex) { return Response.Error($"Failed to read script: {ex.Message}"); } };
string fileText;
try { fileText = File.ReadAllText(fullPath); }
catch (Exception ex) { return Response.Error($"Failed to read script: {ex.Message}"); }
bool ok = ValidateScriptSyntax(fileText, chosen, out string[] diagsRaw); bool ok = ValidateScriptSyntax(fileText, chosen, out string[] diagsRaw);
var diags = (diagsRaw ?? Array.Empty<string>()).Select(s => var diags = (diagsRaw ?? Array.Empty<string>()).Select(s =>
{ {
var m = Regex.Match( var m = Regex.Match(
s, s,
@"^(ERROR|WARNING|INFO): (.*?)(?: \(Line (\d+)\))?$", @"^(ERROR|WARNING|INFO): (.*?)(?: \(Line (\d+)\))?$",
RegexOptions.CultureInvariant | RegexOptions.Multiline, RegexOptions.CultureInvariant | RegexOptions.Multiline,
TimeSpan.FromMilliseconds(250) TimeSpan.FromMilliseconds(250)
); );
string severity = m.Success ? m.Groups[1].Value.ToLowerInvariant() : "info"; string severity = m.Success ? m.Groups[1].Value.ToLowerInvariant() : "info";
string message = m.Success ? m.Groups[2].Value : s; string message = m.Success ? m.Groups[2].Value : s;
int lineNum = m.Success && int.TryParse(m.Groups[3].Value, out var l) ? l : 0; int lineNum = m.Success && int.TryParse(m.Groups[3].Value, out var l) ? l : 0;
return new { line = lineNum, col = 0, severity, message }; return new { line = lineNum, col = 0, severity, message };
}).ToArray(); }).ToArray();
var result = new { diagnostics = diags }; var result = new { diagnostics = diags };
return ok ? Response.Success("Validation completed.", result) return ok ? Response.Success("Validation completed.", result)
: Response.Error("Validation failed.", result); : Response.Error("Validation failed.", result);
} }
case "edit": case "edit":
Debug.LogWarning("manage_script.edit is deprecated; prefer apply_text_edits. Serving structured edit for backward compatibility."); Debug.LogWarning("manage_script.edit is deprecated; prefer apply_text_edits. Serving structured edit for backward compatibility.");
var structEdits = @params["edits"] as JArray; var structEdits = @params["edits"] as JArray;
var options = @params["options"] as JObject; var options = @params["options"] as JObject;
return EditScript(fullPath, relativePath, name, structEdits, options); return EditScript(fullPath, relativePath, name, structEdits, options);
case "get_sha": case "get_sha":
{
try
{ {
if (!File.Exists(fullPath)) try
return Response.Error($"Script not found at '{relativePath}'.");
string text = File.ReadAllText(fullPath);
string sha = ComputeSha256(text);
var fi = new FileInfo(fullPath);
long lengthBytes;
try { lengthBytes = new System.Text.UTF8Encoding(encoderShouldEmitUTF8Identifier: false).GetByteCount(text); }
catch { lengthBytes = fi.Exists ? fi.Length : 0; }
var data = new
{ {
uri = $"unity://path/{relativePath}", if (!File.Exists(fullPath))
path = relativePath, return Response.Error($"Script not found at '{relativePath}'.");
sha256 = sha,
lengthBytes, string text = File.ReadAllText(fullPath);
lastModifiedUtc = fi.Exists ? fi.LastWriteTimeUtc.ToString("o") : string.Empty string sha = ComputeSha256(text);
}; var fi = new FileInfo(fullPath);
return Response.Success($"SHA computed for '{relativePath}'.", data); long lengthBytes;
try { lengthBytes = new System.Text.UTF8Encoding(encoderShouldEmitUTF8Identifier: false).GetByteCount(text); }
catch { lengthBytes = fi.Exists ? fi.Length : 0; }
var data = new
{
uri = $"unity://path/{relativePath}",
path = relativePath,
sha256 = sha,
lengthBytes,
lastModifiedUtc = fi.Exists ? fi.LastWriteTimeUtc.ToString("o") : string.Empty
};
return Response.Success($"SHA computed for '{relativePath}'.", data);
}
catch (Exception ex)
{
return Response.Error($"Failed to compute SHA: {ex.Message}");
}
} }
catch (Exception ex)
{
return Response.Error($"Failed to compute SHA: {ex.Message}");
}
}
default: default:
return Response.Error( return Response.Error(
$"Unknown action: '{action}'. Valid actions are: create, delete, apply_text_edits, validate, read (deprecated), update (deprecated), edit (deprecated)." $"Unknown action: '{action}'. Valid actions are: create, delete, apply_text_edits, validate, read (deprecated), update (deprecated), edit (deprecated)."
@ -505,7 +505,7 @@ namespace MCPForUnity.Editor.Tools
try try
{ {
var di = new DirectoryInfo(Path.GetDirectoryName(fullPath) ?? ""); var di = new DirectoryInfo(Path.GetDirectoryName(fullPath) ?? "");
while (di != null && !string.Equals(di.FullName.Replace('\\','/'), Application.dataPath.Replace('\\','/'), StringComparison.OrdinalIgnoreCase)) while (di != null && !string.Equals(di.FullName.Replace('\\', '/'), Application.dataPath.Replace('\\', '/'), StringComparison.OrdinalIgnoreCase))
{ {
if (di.Exists && (di.Attributes & FileAttributes.ReparsePoint) != 0) if (di.Exists && (di.Attributes & FileAttributes.ReparsePoint) != 0)
return Response.Error("Refusing to edit a symlinked script path."); return Response.Error("Refusing to edit a symlinked script path.");
@ -640,7 +640,7 @@ namespace MCPForUnity.Editor.Tools
}; };
structEdits.Add(op); structEdits.Add(op);
// Reuse structured path // Reuse structured path
return EditScript(fullPath, relativePath, name, structEdits, new JObject{ ["refresh"] = "immediate", ["validate"] = "standard" }); return EditScript(fullPath, relativePath, name, structEdits, new JObject { ["refresh"] = "immediate", ["validate"] = "standard" });
} }
} }
} }
@ -656,7 +656,7 @@ namespace MCPForUnity.Editor.Tools
spans = spans.OrderByDescending(t => t.start).ToList(); spans = spans.OrderByDescending(t => t.start).ToList();
for (int i = 1; i < spans.Count; i++) for (int i = 1; i < spans.Count; i++)
{ {
if (spans[i].end > spans[i - 1].start) if (spans[i].end > spans[i - 1].start)
{ {
var conflict = new[] { new { startA = spans[i].start, endA = spans[i].end, startB = spans[i - 1].start, endB = spans[i - 1].end } }; var conflict = new[] { new { startA = spans[i].start, endA = spans[i].end, startB = spans[i - 1].start, endB = spans[i - 1].end } };
return Response.Error("overlap", new { status = "overlap", conflicts = conflict, hint = "Sort ranges descending by start and compute from the same snapshot." }); return Response.Error("overlap", new { status = "overlap", conflicts = conflict, hint = "Sort ranges descending by start and compute from the same snapshot." });
@ -942,8 +942,10 @@ namespace MCPForUnity.Editor.Tools
if (c == '\'') { inChr = true; esc = false; continue; } if (c == '\'') { inChr = true; esc = false; continue; }
if (c == '/' && n == '/') { while (i < end && text[i] != '\n') i++; continue; } if (c == '/' && n == '/') { while (i < end && text[i] != '\n') i++; continue; }
if (c == '/' && n == '*') { i += 2; while (i + 1 < end && !(text[i] == '*' && text[i + 1] == '/')) i++; i++; continue; } if (c == '/' && n == '*') { i += 2; while (i + 1 < end && !(text[i] == '*' && text[i + 1] == '/')) i++; i++; continue; }
if (c == '{') brace++; else if (c == '}') brace--; if (c == '{') brace++;
else if (c == '(') paren++; else if (c == ')') paren--; else if (c == '}') brace--;
else if (c == '(') paren++;
else if (c == ')') paren--;
else if (c == '[') bracket++; else if (c == ']') bracket--; else if (c == '[') bracket++; else if (c == ']') bracket--;
// Allow temporary negative balance - will check tolerance at end // Allow temporary negative balance - will check tolerance at end
} }
@ -1035,291 +1037,291 @@ namespace MCPForUnity.Editor.Tools
switch (mode) switch (mode)
{ {
case "replace_class": case "replace_class":
{
string className = op.Value<string>("className");
string ns = op.Value<string>("namespace");
string replacement = ExtractReplacement(op);
if (string.IsNullOrWhiteSpace(className))
return Response.Error("replace_class requires 'className'.");
if (replacement == null)
return Response.Error("replace_class requires 'replacement' (inline or base64).");
if (!TryComputeClassSpan(working, className, ns, out var spanStart, out var spanLength, out var why))
return Response.Error($"replace_class failed: {why}");
if (!ValidateClassSnippet(replacement, className, out var vErr))
return Response.Error($"Replacement snippet invalid: {vErr}");
if (applySequentially)
{ {
working = working.Remove(spanStart, spanLength).Insert(spanStart, NormalizeNewlines(replacement)); string className = op.Value<string>("className");
appliedCount++; string ns = op.Value<string>("namespace");
string replacement = ExtractReplacement(op);
if (string.IsNullOrWhiteSpace(className))
return Response.Error("replace_class requires 'className'.");
if (replacement == null)
return Response.Error("replace_class requires 'replacement' (inline or base64).");
if (!TryComputeClassSpan(working, className, ns, out var spanStart, out var spanLength, out var why))
return Response.Error($"replace_class failed: {why}");
if (!ValidateClassSnippet(replacement, className, out var vErr))
return Response.Error($"Replacement snippet invalid: {vErr}");
if (applySequentially)
{
working = working.Remove(spanStart, spanLength).Insert(spanStart, NormalizeNewlines(replacement));
appliedCount++;
}
else
{
replacements.Add((spanStart, spanLength, NormalizeNewlines(replacement)));
}
break;
} }
else
{
replacements.Add((spanStart, spanLength, NormalizeNewlines(replacement)));
}
break;
}
case "delete_class": case "delete_class":
{
string className = op.Value<string>("className");
string ns = op.Value<string>("namespace");
if (string.IsNullOrWhiteSpace(className))
return Response.Error("delete_class requires 'className'.");
if (!TryComputeClassSpan(working, className, ns, out var s, out var l, out var why))
return Response.Error($"delete_class failed: {why}");
if (applySequentially)
{ {
working = working.Remove(s, l); string className = op.Value<string>("className");
appliedCount++; string ns = op.Value<string>("namespace");
if (string.IsNullOrWhiteSpace(className))
return Response.Error("delete_class requires 'className'.");
if (!TryComputeClassSpan(working, className, ns, out var s, out var l, out var why))
return Response.Error($"delete_class failed: {why}");
if (applySequentially)
{
working = working.Remove(s, l);
appliedCount++;
}
else
{
replacements.Add((s, l, string.Empty));
}
break;
} }
else
{
replacements.Add((s, l, string.Empty));
}
break;
}
case "replace_method": case "replace_method":
{
string className = op.Value<string>("className");
string ns = op.Value<string>("namespace");
string methodName = op.Value<string>("methodName");
string replacement = ExtractReplacement(op);
string returnType = op.Value<string>("returnType");
string parametersSignature = op.Value<string>("parametersSignature");
string attributesContains = op.Value<string>("attributesContains");
if (string.IsNullOrWhiteSpace(className)) return Response.Error("replace_method requires 'className'.");
if (string.IsNullOrWhiteSpace(methodName)) return Response.Error("replace_method requires 'methodName'.");
if (replacement == null) return Response.Error("replace_method requires 'replacement' (inline or base64).");
if (!TryComputeClassSpan(working, className, ns, out var clsStart, out var clsLen, out var whyClass))
return Response.Error($"replace_method failed to locate class: {whyClass}");
if (!TryComputeMethodSpan(working, clsStart, clsLen, methodName, returnType, parametersSignature, attributesContains, out var mStart, out var mLen, out var whyMethod))
{ {
bool hasDependentInsert = edits.Any(j => j is JObject jo && string className = op.Value<string>("className");
string.Equals(jo.Value<string>("className"), className, StringComparison.Ordinal) && string ns = op.Value<string>("namespace");
string.Equals(jo.Value<string>("methodName"), methodName, StringComparison.Ordinal) && string methodName = op.Value<string>("methodName");
((jo.Value<string>("mode") ?? jo.Value<string>("op") ?? string.Empty).ToLowerInvariant() == "insert_method")); string replacement = ExtractReplacement(op);
string hint = hasDependentInsert && !applySequentially ? " Hint: This batch inserts this method. Use options.applyMode='sequential' or split into separate calls." : string.Empty; string returnType = op.Value<string>("returnType");
return Response.Error($"replace_method failed: {whyMethod}.{hint}"); string parametersSignature = op.Value<string>("parametersSignature");
} string attributesContains = op.Value<string>("attributesContains");
if (applySequentially) if (string.IsNullOrWhiteSpace(className)) return Response.Error("replace_method requires 'className'.");
{ if (string.IsNullOrWhiteSpace(methodName)) return Response.Error("replace_method requires 'methodName'.");
working = working.Remove(mStart, mLen).Insert(mStart, NormalizeNewlines(replacement)); if (replacement == null) return Response.Error("replace_method requires 'replacement' (inline or base64).");
appliedCount++;
if (!TryComputeClassSpan(working, className, ns, out var clsStart, out var clsLen, out var whyClass))
return Response.Error($"replace_method failed to locate class: {whyClass}");
if (!TryComputeMethodSpan(working, clsStart, clsLen, methodName, returnType, parametersSignature, attributesContains, out var mStart, out var mLen, out var whyMethod))
{
bool hasDependentInsert = edits.Any(j => j is JObject jo &&
string.Equals(jo.Value<string>("className"), className, StringComparison.Ordinal) &&
string.Equals(jo.Value<string>("methodName"), methodName, StringComparison.Ordinal) &&
((jo.Value<string>("mode") ?? jo.Value<string>("op") ?? string.Empty).ToLowerInvariant() == "insert_method"));
string hint = hasDependentInsert && !applySequentially ? " Hint: This batch inserts this method. Use options.applyMode='sequential' or split into separate calls." : string.Empty;
return Response.Error($"replace_method failed: {whyMethod}.{hint}");
}
if (applySequentially)
{
working = working.Remove(mStart, mLen).Insert(mStart, NormalizeNewlines(replacement));
appliedCount++;
}
else
{
replacements.Add((mStart, mLen, NormalizeNewlines(replacement)));
}
break;
} }
else
{
replacements.Add((mStart, mLen, NormalizeNewlines(replacement)));
}
break;
}
case "delete_method": case "delete_method":
{
string className = op.Value<string>("className");
string ns = op.Value<string>("namespace");
string methodName = op.Value<string>("methodName");
string returnType = op.Value<string>("returnType");
string parametersSignature = op.Value<string>("parametersSignature");
string attributesContains = op.Value<string>("attributesContains");
if (string.IsNullOrWhiteSpace(className)) return Response.Error("delete_method requires 'className'.");
if (string.IsNullOrWhiteSpace(methodName)) return Response.Error("delete_method requires 'methodName'.");
if (!TryComputeClassSpan(working, className, ns, out var clsStart, out var clsLen, out var whyClass))
return Response.Error($"delete_method failed to locate class: {whyClass}");
if (!TryComputeMethodSpan(working, clsStart, clsLen, methodName, returnType, parametersSignature, attributesContains, out var mStart, out var mLen, out var whyMethod))
{ {
bool hasDependentInsert = edits.Any(j => j is JObject jo && string className = op.Value<string>("className");
string.Equals(jo.Value<string>("className"), className, StringComparison.Ordinal) && string ns = op.Value<string>("namespace");
string.Equals(jo.Value<string>("methodName"), methodName, StringComparison.Ordinal) && string methodName = op.Value<string>("methodName");
((jo.Value<string>("mode") ?? jo.Value<string>("op") ?? string.Empty).ToLowerInvariant() == "insert_method")); string returnType = op.Value<string>("returnType");
string hint = hasDependentInsert && !applySequentially ? " Hint: This batch inserts this method. Use options.applyMode='sequential' or split into separate calls." : string.Empty; string parametersSignature = op.Value<string>("parametersSignature");
return Response.Error($"delete_method failed: {whyMethod}.{hint}"); string attributesContains = op.Value<string>("attributesContains");
}
if (applySequentially) if (string.IsNullOrWhiteSpace(className)) return Response.Error("delete_method requires 'className'.");
{ if (string.IsNullOrWhiteSpace(methodName)) return Response.Error("delete_method requires 'methodName'.");
working = working.Remove(mStart, mLen);
appliedCount++; if (!TryComputeClassSpan(working, className, ns, out var clsStart, out var clsLen, out var whyClass))
return Response.Error($"delete_method failed to locate class: {whyClass}");
if (!TryComputeMethodSpan(working, clsStart, clsLen, methodName, returnType, parametersSignature, attributesContains, out var mStart, out var mLen, out var whyMethod))
{
bool hasDependentInsert = edits.Any(j => j is JObject jo &&
string.Equals(jo.Value<string>("className"), className, StringComparison.Ordinal) &&
string.Equals(jo.Value<string>("methodName"), methodName, StringComparison.Ordinal) &&
((jo.Value<string>("mode") ?? jo.Value<string>("op") ?? string.Empty).ToLowerInvariant() == "insert_method"));
string hint = hasDependentInsert && !applySequentially ? " Hint: This batch inserts this method. Use options.applyMode='sequential' or split into separate calls." : string.Empty;
return Response.Error($"delete_method failed: {whyMethod}.{hint}");
}
if (applySequentially)
{
working = working.Remove(mStart, mLen);
appliedCount++;
}
else
{
replacements.Add((mStart, mLen, string.Empty));
}
break;
} }
else
{
replacements.Add((mStart, mLen, string.Empty));
}
break;
}
case "insert_method": case "insert_method":
{
string className = op.Value<string>("className");
string ns = op.Value<string>("namespace");
string position = (op.Value<string>("position") ?? "end").ToLowerInvariant();
string afterMethodName = op.Value<string>("afterMethodName");
string afterReturnType = op.Value<string>("afterReturnType");
string afterParameters = op.Value<string>("afterParametersSignature");
string afterAttributesContains = op.Value<string>("afterAttributesContains");
string snippet = ExtractReplacement(op);
// Harden: refuse empty replacement for inserts
if (snippet == null || snippet.Trim().Length == 0)
return Response.Error("insert_method requires a non-empty 'replacement' text.");
if (string.IsNullOrWhiteSpace(className)) return Response.Error("insert_method requires 'className'.");
if (snippet == null) return Response.Error("insert_method requires 'replacement' (inline or base64) containing a full method declaration.");
if (!TryComputeClassSpan(working, className, ns, out var clsStart, out var clsLen, out var whyClass))
return Response.Error($"insert_method failed to locate class: {whyClass}");
if (position == "after")
{ {
if (string.IsNullOrEmpty(afterMethodName)) return Response.Error("insert_method with position='after' requires 'afterMethodName'."); string className = op.Value<string>("className");
if (!TryComputeMethodSpan(working, clsStart, clsLen, afterMethodName, afterReturnType, afterParameters, afterAttributesContains, out var aStart, out var aLen, out var whyAfter)) string ns = op.Value<string>("namespace");
return Response.Error($"insert_method(after) failed to locate anchor method: {whyAfter}"); string position = (op.Value<string>("position") ?? "end").ToLowerInvariant();
int insAt = aStart + aLen; string afterMethodName = op.Value<string>("afterMethodName");
string text = NormalizeNewlines("\n\n" + snippet.TrimEnd() + "\n"); string afterReturnType = op.Value<string>("afterReturnType");
if (applySequentially) string afterParameters = op.Value<string>("afterParametersSignature");
{ string afterAttributesContains = op.Value<string>("afterAttributesContains");
working = working.Insert(insAt, text); string snippet = ExtractReplacement(op);
appliedCount++; // Harden: refuse empty replacement for inserts
} if (snippet == null || snippet.Trim().Length == 0)
else return Response.Error("insert_method requires a non-empty 'replacement' text.");
{
replacements.Add((insAt, 0, text));
}
}
else if (!TryFindClassInsertionPoint(working, clsStart, clsLen, position, out var insAt, out var whyIns))
return Response.Error($"insert_method failed: {whyIns}");
else
{
string text = NormalizeNewlines("\n\n" + snippet.TrimEnd() + "\n");
if (applySequentially)
{
working = working.Insert(insAt, text);
appliedCount++;
}
else
{
replacements.Add((insAt, 0, text));
}
}
break;
}
case "anchor_insert": if (string.IsNullOrWhiteSpace(className)) return Response.Error("insert_method requires 'className'.");
{ if (snippet == null) return Response.Error("insert_method requires 'replacement' (inline or base64) containing a full method declaration.");
string anchor = op.Value<string>("anchor");
string position = (op.Value<string>("position") ?? "before").ToLowerInvariant();
string text = op.Value<string>("text") ?? ExtractReplacement(op);
if (string.IsNullOrWhiteSpace(anchor)) return Response.Error("anchor_insert requires 'anchor' (regex).");
if (string.IsNullOrEmpty(text)) return Response.Error("anchor_insert requires non-empty 'text'.");
try if (!TryComputeClassSpan(working, className, ns, out var clsStart, out var clsLen, out var whyClass))
{ return Response.Error($"insert_method failed to locate class: {whyClass}");
var rx = new Regex(anchor, RegexOptions.Multiline, TimeSpan.FromSeconds(2));
var m = rx.Match(working);
if (!m.Success) return Response.Error($"anchor_insert: anchor not found: {anchor}");
int insAt = position == "after" ? m.Index + m.Length : m.Index;
string norm = NormalizeNewlines(text);
if (!norm.EndsWith("\n"))
{
norm += "\n";
}
// Duplicate guard: if identical snippet already exists within this class, skip insert if (position == "after")
if (TryComputeClassSpan(working, name, null, out var clsStartDG, out var clsLenDG, out _))
{ {
string classSlice = working.Substring(clsStartDG, Math.Min(clsLenDG, working.Length - clsStartDG)); if (string.IsNullOrEmpty(afterMethodName)) return Response.Error("insert_method with position='after' requires 'afterMethodName'.");
if (classSlice.IndexOf(norm, StringComparison.Ordinal) >= 0) if (!TryComputeMethodSpan(working, clsStart, clsLen, afterMethodName, afterReturnType, afterParameters, afterAttributesContains, out var aStart, out var aLen, out var whyAfter))
return Response.Error($"insert_method(after) failed to locate anchor method: {whyAfter}");
int insAt = aStart + aLen;
string text = NormalizeNewlines("\n\n" + snippet.TrimEnd() + "\n");
if (applySequentially)
{ {
// Do not insert duplicate; treat as no-op working = working.Insert(insAt, text);
break; appliedCount++;
}
else
{
replacements.Add((insAt, 0, text));
} }
} }
if (applySequentially) else if (!TryFindClassInsertionPoint(working, clsStart, clsLen, position, out var insAt, out var whyIns))
{ return Response.Error($"insert_method failed: {whyIns}");
working = working.Insert(insAt, norm);
appliedCount++;
}
else else
{ {
replacements.Add((insAt, 0, norm)); string text = NormalizeNewlines("\n\n" + snippet.TrimEnd() + "\n");
if (applySequentially)
{
working = working.Insert(insAt, text);
appliedCount++;
}
else
{
replacements.Add((insAt, 0, text));
}
} }
break;
} }
catch (Exception ex)
case "anchor_insert":
{ {
return Response.Error($"anchor_insert failed: {ex.Message}"); string anchor = op.Value<string>("anchor");
string position = (op.Value<string>("position") ?? "before").ToLowerInvariant();
string text = op.Value<string>("text") ?? ExtractReplacement(op);
if (string.IsNullOrWhiteSpace(anchor)) return Response.Error("anchor_insert requires 'anchor' (regex).");
if (string.IsNullOrEmpty(text)) return Response.Error("anchor_insert requires non-empty 'text'.");
try
{
var rx = new Regex(anchor, RegexOptions.Multiline, TimeSpan.FromSeconds(2));
var m = rx.Match(working);
if (!m.Success) return Response.Error($"anchor_insert: anchor not found: {anchor}");
int insAt = position == "after" ? m.Index + m.Length : m.Index;
string norm = NormalizeNewlines(text);
if (!norm.EndsWith("\n"))
{
norm += "\n";
}
// Duplicate guard: if identical snippet already exists within this class, skip insert
if (TryComputeClassSpan(working, name, null, out var clsStartDG, out var clsLenDG, out _))
{
string classSlice = working.Substring(clsStartDG, Math.Min(clsLenDG, working.Length - clsStartDG));
if (classSlice.IndexOf(norm, StringComparison.Ordinal) >= 0)
{
// Do not insert duplicate; treat as no-op
break;
}
}
if (applySequentially)
{
working = working.Insert(insAt, norm);
appliedCount++;
}
else
{
replacements.Add((insAt, 0, norm));
}
}
catch (Exception ex)
{
return Response.Error($"anchor_insert failed: {ex.Message}");
}
break;
} }
break;
}
case "anchor_delete": case "anchor_delete":
{
string anchor = op.Value<string>("anchor");
if (string.IsNullOrWhiteSpace(anchor)) return Response.Error("anchor_delete requires 'anchor' (regex).");
try
{ {
var rx = new Regex(anchor, RegexOptions.Multiline, TimeSpan.FromSeconds(2)); string anchor = op.Value<string>("anchor");
var m = rx.Match(working); if (string.IsNullOrWhiteSpace(anchor)) return Response.Error("anchor_delete requires 'anchor' (regex).");
if (!m.Success) return Response.Error($"anchor_delete: anchor not found: {anchor}"); try
int delAt = m.Index;
int delLen = m.Length;
if (applySequentially)
{ {
working = working.Remove(delAt, delLen); var rx = new Regex(anchor, RegexOptions.Multiline, TimeSpan.FromSeconds(2));
appliedCount++; var m = rx.Match(working);
if (!m.Success) return Response.Error($"anchor_delete: anchor not found: {anchor}");
int delAt = m.Index;
int delLen = m.Length;
if (applySequentially)
{
working = working.Remove(delAt, delLen);
appliedCount++;
}
else
{
replacements.Add((delAt, delLen, string.Empty));
}
} }
else catch (Exception ex)
{ {
replacements.Add((delAt, delLen, string.Empty)); return Response.Error($"anchor_delete failed: {ex.Message}");
} }
break;
} }
catch (Exception ex)
{
return Response.Error($"anchor_delete failed: {ex.Message}");
}
break;
}
case "anchor_replace": case "anchor_replace":
{
string anchor = op.Value<string>("anchor");
string replacement = op.Value<string>("text") ?? op.Value<string>("replacement") ?? ExtractReplacement(op) ?? string.Empty;
if (string.IsNullOrWhiteSpace(anchor)) return Response.Error("anchor_replace requires 'anchor' (regex).");
try
{ {
var rx = new Regex(anchor, RegexOptions.Multiline, TimeSpan.FromSeconds(2)); string anchor = op.Value<string>("anchor");
var m = rx.Match(working); string replacement = op.Value<string>("text") ?? op.Value<string>("replacement") ?? ExtractReplacement(op) ?? string.Empty;
if (!m.Success) return Response.Error($"anchor_replace: anchor not found: {anchor}"); if (string.IsNullOrWhiteSpace(anchor)) return Response.Error("anchor_replace requires 'anchor' (regex).");
int at = m.Index; try
int len = m.Length;
string norm = NormalizeNewlines(replacement);
if (applySequentially)
{ {
working = working.Remove(at, len).Insert(at, norm); var rx = new Regex(anchor, RegexOptions.Multiline, TimeSpan.FromSeconds(2));
appliedCount++; var m = rx.Match(working);
if (!m.Success) return Response.Error($"anchor_replace: anchor not found: {anchor}");
int at = m.Index;
int len = m.Length;
string norm = NormalizeNewlines(replacement);
if (applySequentially)
{
working = working.Remove(at, len).Insert(at, norm);
appliedCount++;
}
else
{
replacements.Add((at, len, norm));
}
} }
else catch (Exception ex)
{ {
replacements.Add((at, len, norm)); return Response.Error($"anchor_replace failed: {ex.Message}");
} }
break;
} }
catch (Exception ex)
{
return Response.Error($"anchor_replace failed: {ex.Message}");
}
break;
}
default: default:
return Response.Error($"Unknown edit mode: '{mode}'. Allowed: replace_class, delete_class, replace_method, delete_method, insert_method, anchor_insert, anchor_delete, anchor_replace."); return Response.Error($"Unknown edit mode: '{mode}'. Allowed: replace_class, delete_class, replace_method, delete_method, insert_method, anchor_insert, anchor_delete, anchor_replace.");
@ -1703,7 +1705,7 @@ namespace MCPForUnity.Editor.Tools
} }
// Tolerate generic constraints between params and body: multiple 'where T : ...' // Tolerate generic constraints between params and body: multiple 'where T : ...'
for (;;) for (; ; )
{ {
// Skip whitespace/comments before checking for 'where' // Skip whitespace/comments before checking for 'where'
for (; i < searchEnd; i++) for (; i < searchEnd; i++)

View File

@ -24,4 +24,4 @@ RUN uv pip install --system -e .
# Command to run the server # Command to run the server
CMD ["uv", "run", "server.py"] CMD ["uv", "run", "server.py"]

View File

@ -1,3 +1,3 @@
""" """
MCP for Unity Server package. MCP for Unity Server package.
""" """

View File

@ -5,26 +5,30 @@ This file contains all configurable parameters for the server.
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
class ServerConfig: class ServerConfig:
"""Main configuration class for the MCP server.""" """Main configuration class for the MCP server."""
# Network settings # Network settings
unity_host: str = "localhost" unity_host: str = "localhost"
unity_port: int = 6400 unity_port: int = 6400
mcp_port: int = 6500 mcp_port: int = 6500
# Connection settings # Connection settings
connection_timeout: float = 1.0 # short initial timeout; retries use shorter timeouts # short initial timeout; retries use shorter timeouts
connection_timeout: float = 1.0
buffer_size: int = 16 * 1024 * 1024 # 16MB buffer buffer_size: int = 16 * 1024 * 1024 # 16MB buffer
# Framed receive behavior # Framed receive behavior
framed_receive_timeout: float = 2.0 # max seconds to wait while consuming heartbeats only # max seconds to wait while consuming heartbeats only
max_heartbeat_frames: int = 16 # cap heartbeat frames consumed before giving up framed_receive_timeout: float = 2.0
# cap heartbeat frames consumed before giving up
max_heartbeat_frames: int = 16
# Logging settings # Logging settings
log_level: str = "INFO" log_level: str = "INFO"
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# Server settings # Server settings
max_retries: int = 10 max_retries: int = 10
retry_delay: float = 0.25 retry_delay: float = 0.25
@ -33,11 +37,12 @@ class ServerConfig:
# Number of polite retries when Unity reports reloading # Number of polite retries when Unity reports reloading
# 40 × 250ms ≈ 10s default window # 40 × 250ms ≈ 10s default window
reload_max_retries: int = 40 reload_max_retries: int = 40
# Telemetry settings # Telemetry settings
telemetry_enabled: bool = True telemetry_enabled: bool = True
# Align with telemetry.py default Cloud Run endpoint # Align with telemetry.py default Cloud Run endpoint
telemetry_endpoint: str = "https://api-prod.coplay.dev/telemetry/events" telemetry_endpoint: str = "https://api-prod.coplay.dev/telemetry/events"
# Create a global config instance # Create a global config instance
config = ServerConfig() config = ServerConfig()

View File

@ -11,31 +11,31 @@ What changed and why:
(quick socket connect + ping) before choosing it. (quick socket connect + ping) before choosing it.
""" """
import glob
import json import json
import os
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Optional, List
import glob
import socket import socket
from typing import Optional, List
logger = logging.getLogger("mcp-for-unity-server") logger = logging.getLogger("mcp-for-unity-server")
class PortDiscovery: class PortDiscovery:
"""Handles port discovery from Unity Bridge registry""" """Handles port discovery from Unity Bridge registry"""
REGISTRY_FILE = "unity-mcp-port.json" # legacy single-project file REGISTRY_FILE = "unity-mcp-port.json" # legacy single-project file
DEFAULT_PORT = 6400 DEFAULT_PORT = 6400
CONNECT_TIMEOUT = 0.3 # seconds, keep this snappy during discovery CONNECT_TIMEOUT = 0.3 # seconds, keep this snappy during discovery
@staticmethod @staticmethod
def get_registry_path() -> Path: def get_registry_path() -> Path:
"""Get the path to the port registry file""" """Get the path to the port registry file"""
return Path.home() / ".unity-mcp" / PortDiscovery.REGISTRY_FILE return Path.home() / ".unity-mcp" / PortDiscovery.REGISTRY_FILE
@staticmethod @staticmethod
def get_registry_dir() -> Path: def get_registry_dir() -> Path:
return Path.home() / ".unity-mcp" return Path.home() / ".unity-mcp"
@staticmethod @staticmethod
def list_candidate_files() -> List[Path]: def list_candidate_files() -> List[Path]:
"""Return candidate registry files, newest first. """Return candidate registry files, newest first.
@ -52,7 +52,7 @@ class PortDiscovery:
# Put legacy at the end so hashed, per-project files win # Put legacy at the end so hashed, per-project files win
hashed.append(legacy) hashed.append(legacy)
return hashed return hashed
@staticmethod @staticmethod
def _try_probe_unity_mcp(port: int) -> bool: def _try_probe_unity_mcp(port: int) -> bool:
"""Quickly check if a MCP for Unity listener is on this port. """Quickly check if a MCP for Unity listener is on this port.
@ -78,7 +78,8 @@ class PortDiscovery:
try: try:
base = PortDiscovery.get_registry_dir() base = PortDiscovery.get_registry_dir()
status_files = sorted( status_files = sorted(
(Path(p) for p in glob.glob(str(base / "unity-mcp-status-*.json"))), (Path(p)
for p in glob.glob(str(base / "unity-mcp-status-*.json"))),
key=lambda p: p.stat().st_mtime, key=lambda p: p.stat().st_mtime,
reverse=True, reverse=True,
) )
@ -88,14 +89,14 @@ class PortDiscovery:
return json.load(f) return json.load(f)
except Exception: except Exception:
return None return None
@staticmethod @staticmethod
def discover_unity_port() -> int: def discover_unity_port() -> int:
""" """
Discover Unity port by scanning per-project and legacy registry files. Discover Unity port by scanning per-project and legacy registry files.
Prefer the newest file whose port responds; fall back to first parsed Prefer the newest file whose port responds; fall back to first parsed
value; finally default to 6400. value; finally default to 6400.
Returns: Returns:
Port number to connect to Port number to connect to
""" """
@ -120,26 +121,29 @@ class PortDiscovery:
if first_seen_port is None: if first_seen_port is None:
first_seen_port = unity_port first_seen_port = unity_port
if PortDiscovery._try_probe_unity_mcp(unity_port): if PortDiscovery._try_probe_unity_mcp(unity_port):
logger.info(f"Using Unity port from {path.name}: {unity_port}") logger.info(
f"Using Unity port from {path.name}: {unity_port}")
return unity_port return unity_port
except Exception as e: except Exception as e:
logger.warning(f"Could not read port registry {path}: {e}") logger.warning(f"Could not read port registry {path}: {e}")
if first_seen_port is not None: if first_seen_port is not None:
logger.info(f"No responsive port found; using first seen value {first_seen_port}") logger.info(
f"No responsive port found; using first seen value {first_seen_port}")
return first_seen_port return first_seen_port
# Fallback to default port # Fallback to default port
logger.info(f"No port registry found; using default port {PortDiscovery.DEFAULT_PORT}") logger.info(
f"No port registry found; using default port {PortDiscovery.DEFAULT_PORT}")
return PortDiscovery.DEFAULT_PORT return PortDiscovery.DEFAULT_PORT
@staticmethod @staticmethod
def get_port_config() -> Optional[dict]: def get_port_config() -> Optional[dict]:
""" """
Get the most relevant port configuration from registry. Get the most relevant port configuration from registry.
Returns the most recent hashed file's config if present, Returns the most recent hashed file's config if present,
otherwise the legacy file's config. Returns None if nothing exists. otherwise the legacy file's config. Returns None if nothing exists.
Returns: Returns:
Port configuration dict or None if not found Port configuration dict or None if not found
""" """
@ -151,5 +155,6 @@ class PortDiscovery:
with open(path, 'r') as f: with open(path, 'r') as f:
return json.load(f) return json.load(f)
except Exception as e: except Exception as e:
logger.warning(f"Could not read port configuration {path}: {e}") logger.warning(
return None f"Could not read port configuration {path}: {e}")
return None

View File

@ -4,7 +4,7 @@ version = "4.0.0"
description = "MCP for Unity Server: A Unity package for Unity Editor integration via the Model Context Protocol (MCP)." description = "MCP for Unity Server: A Unity package for Unity Editor integration via the Model Context Protocol (MCP)."
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = ["httpx>=0.27.2", "mcp[cli]>=1.4.1"] dependencies = ["httpx>=0.27.2", "mcp[cli]>=1.15.0"]
[build-system] [build-system]
requires = ["setuptools>=64.0.0", "wheel"] requires = ["setuptools>=64.0.0", "wheel"]

View File

@ -4,5 +4,6 @@ Deprecated: Sentinel flipping is handled inside Unity via the MCP menu
All functions are no-ops to prevent accidental external writes. All functions are no-ops to prevent accidental external writes.
""" """
def flip_reload_sentinel(*args, **kwargs) -> str: def flip_reload_sentinel(*args, **kwargs) -> str:
return "reload_sentinel.py is deprecated; use execute_menu_item → 'MCP/Flip Reload Sentinel'" return "reload_sentinel.py is deprecated; use execute_menu_item → 'MCP/Flip Reload Sentinel'"

View File

@ -1,10 +1,9 @@
from mcp.server.fastmcp import FastMCP, Context, Image from mcp.server.fastmcp import FastMCP
import logging import logging
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
import os import os
from dataclasses import dataclass
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict, Any, List from typing import AsyncIterator, Dict, Any
from config import config from config import config
from tools import register_all_tools from tools import register_all_tools
from unity_connection import get_unity_connection, UnityConnection from unity_connection import get_unity_connection, UnityConnection
@ -150,8 +149,7 @@ async def server_lifespan(server: FastMCP) -> AsyncIterator[Dict[str, Any]]:
# Initialize MCP server # Initialize MCP server
mcp = FastMCP( mcp = FastMCP(
"mcp-for-unity-server", name="mcp-for-unity-server",
description="Unity Editor integration via Model Context Protocol",
lifespan=server_lifespan lifespan=server_lifespan
) )

View File

@ -1,29 +1,28 @@
""" """
Privacy-focused, anonymous telemetry system for Unity MCP Privacy-focused, anonymous telemetry system for Unity MCP
Inspired by Onyx's telemetry implementation with Unity-specific adaptations Inspired by Onyx's telemetry implementation with Unity-specific adaptations
"""
import uuid
import threading
"""
Fire-and-forget telemetry sender with a single background worker. Fire-and-forget telemetry sender with a single background worker.
- No context/thread-local propagation to avoid re-entrancy into tool resolution. - No context/thread-local propagation to avoid re-entrancy into tool resolution.
- Small network timeouts to prevent stalls. - Small network timeouts to prevent stalls.
""" """
import json
import time
import os
import sys
import platform
import logging
from enum import Enum
from urllib.parse import urlparse
from dataclasses import dataclass, asdict
from typing import Optional, Dict, Any, List
from pathlib import Path
import importlib
import queue
import contextlib import contextlib
from dataclasses import dataclass
from enum import Enum
import importlib
import json
import logging
import os
from pathlib import Path
import platform
import queue
import sys
import threading
import time
from typing import Optional, Dict, Any
from urllib.parse import urlparse
import uuid
try: try:
import httpx import httpx
@ -34,10 +33,11 @@ except ImportError:
logger = logging.getLogger("unity-mcp-telemetry") logger = logging.getLogger("unity-mcp-telemetry")
class RecordType(str, Enum): class RecordType(str, Enum):
"""Types of telemetry records we collect""" """Types of telemetry records we collect"""
VERSION = "version" VERSION = "version"
STARTUP = "startup" STARTUP = "startup"
USAGE = "usage" USAGE = "usage"
LATENCY = "latency" LATENCY = "latency"
FAILURE = "failure" FAILURE = "failure"
@ -45,6 +45,7 @@ class RecordType(str, Enum):
UNITY_CONNECTION = "unity_connection" UNITY_CONNECTION = "unity_connection"
CLIENT_CONNECTION = "client_connection" CLIENT_CONNECTION = "client_connection"
class MilestoneType(str, Enum): class MilestoneType(str, Enum):
"""Major user journey milestones""" """Major user journey milestones"""
FIRST_STARTUP = "first_startup" FIRST_STARTUP = "first_startup"
@ -55,6 +56,7 @@ class MilestoneType(str, Enum):
DAILY_ACTIVE_USER = "daily_active_user" DAILY_ACTIVE_USER = "daily_active_user"
WEEKLY_ACTIVE_USER = "weekly_active_user" WEEKLY_ACTIVE_USER = "weekly_active_user"
@dataclass @dataclass
class TelemetryRecord: class TelemetryRecord:
"""Structure for telemetry data""" """Structure for telemetry data"""
@ -65,8 +67,10 @@ class TelemetryRecord:
data: Dict[str, Any] data: Dict[str, Any]
milestone: Optional[MilestoneType] = None milestone: Optional[MilestoneType] = None
class TelemetryConfig: class TelemetryConfig:
"""Telemetry configuration""" """Telemetry configuration"""
def __init__(self): def __init__(self):
# Prefer config file, then allow env overrides # Prefer config file, then allow env overrides
server_config = None server_config = None
@ -85,11 +89,13 @@ class TelemetryConfig:
continue continue
# Determine enabled flag: config -> env DISABLE_* opt-out # Determine enabled flag: config -> env DISABLE_* opt-out
cfg_enabled = True if server_config is None else bool(getattr(server_config, "telemetry_enabled", True)) cfg_enabled = True if server_config is None else bool(
getattr(server_config, "telemetry_enabled", True))
self.enabled = cfg_enabled and not self._is_disabled() self.enabled = cfg_enabled and not self._is_disabled()
# Telemetry endpoint (Cloud Run default; override via env) # Telemetry endpoint (Cloud Run default; override via env)
cfg_default = None if server_config is None else getattr(server_config, "telemetry_endpoint", None) cfg_default = None if server_config is None else getattr(
server_config, "telemetry_endpoint", None)
default_ep = cfg_default or "https://api-prod.coplay.dev/telemetry/events" default_ep = cfg_default or "https://api-prod.coplay.dev/telemetry/events"
self.default_endpoint = default_ep self.default_endpoint = default_ep
self.endpoint = self._validated_endpoint( self.endpoint = self._validated_endpoint(
@ -105,50 +111,53 @@ class TelemetryConfig:
) )
except Exception: except Exception:
pass pass
# Local storage for UUID and milestones # Local storage for UUID and milestones
self.data_dir = self._get_data_directory() self.data_dir = self._get_data_directory()
self.uuid_file = self.data_dir / "customer_uuid.txt" self.uuid_file = self.data_dir / "customer_uuid.txt"
self.milestones_file = self.data_dir / "milestones.json" self.milestones_file = self.data_dir / "milestones.json"
# Request timeout (small, fail fast). Override with UNITY_MCP_TELEMETRY_TIMEOUT # Request timeout (small, fail fast). Override with UNITY_MCP_TELEMETRY_TIMEOUT
try: try:
self.timeout = float(os.environ.get("UNITY_MCP_TELEMETRY_TIMEOUT", "1.5")) self.timeout = float(os.environ.get(
"UNITY_MCP_TELEMETRY_TIMEOUT", "1.5"))
except Exception: except Exception:
self.timeout = 1.5 self.timeout = 1.5
try: try:
logger.info("Telemetry timeout=%.2fs", self.timeout) logger.info("Telemetry timeout=%.2fs", self.timeout)
except Exception: except Exception:
pass pass
# Session tracking # Session tracking
self.session_id = str(uuid.uuid4()) self.session_id = str(uuid.uuid4())
def _is_disabled(self) -> bool: def _is_disabled(self) -> bool:
"""Check if telemetry is disabled via environment variables""" """Check if telemetry is disabled via environment variables"""
disable_vars = [ disable_vars = [
"DISABLE_TELEMETRY", "DISABLE_TELEMETRY",
"UNITY_MCP_DISABLE_TELEMETRY", "UNITY_MCP_DISABLE_TELEMETRY",
"MCP_DISABLE_TELEMETRY" "MCP_DISABLE_TELEMETRY"
] ]
for var in disable_vars: for var in disable_vars:
if os.environ.get(var, "").lower() in ("true", "1", "yes", "on"): if os.environ.get(var, "").lower() in ("true", "1", "yes", "on"):
return True return True
return False return False
def _get_data_directory(self) -> Path: def _get_data_directory(self) -> Path:
"""Get directory for storing telemetry data""" """Get directory for storing telemetry data"""
if os.name == 'nt': # Windows if os.name == 'nt': # Windows
base_dir = Path(os.environ.get('APPDATA', Path.home() / 'AppData' / 'Roaming')) base_dir = Path(os.environ.get(
'APPDATA', Path.home() / 'AppData' / 'Roaming'))
elif os.name == 'posix': # macOS/Linux elif os.name == 'posix': # macOS/Linux
if 'darwin' in os.uname().sysname.lower(): # macOS if 'darwin' in os.uname().sysname.lower(): # macOS
base_dir = Path.home() / 'Library' / 'Application Support' base_dir = Path.home() / 'Library' / 'Application Support'
else: # Linux else: # Linux
base_dir = Path(os.environ.get('XDG_DATA_HOME', Path.home() / '.local' / 'share')) base_dir = Path(os.environ.get('XDG_DATA_HOME',
Path.home() / '.local' / 'share'))
else: else:
base_dir = Path.home() / '.unity-mcp' base_dir = Path.home() / '.unity-mcp'
data_dir = base_dir / 'UnityMCP' data_dir = base_dir / 'UnityMCP'
data_dir.mkdir(parents=True, exist_ok=True) data_dir.mkdir(parents=True, exist_ok=True)
return data_dir return data_dir
@ -167,7 +176,8 @@ class TelemetryConfig:
# Reject localhost/loopback endpoints in production to avoid accidental local overrides # Reject localhost/loopback endpoints in production to avoid accidental local overrides
host = parsed.hostname or "" host = parsed.hostname or ""
if host in ("localhost", "127.0.0.1", "::1"): if host in ("localhost", "127.0.0.1", "::1"):
raise ValueError("Localhost endpoints are not allowed for telemetry") raise ValueError(
"Localhost endpoints are not allowed for telemetry")
return candidate return candidate
except Exception as e: except Exception as e:
logger.debug( logger.debug(
@ -176,9 +186,10 @@ class TelemetryConfig:
) )
return fallback return fallback
class TelemetryCollector: class TelemetryCollector:
"""Main telemetry collection class""" """Main telemetry collection class"""
def __init__(self): def __init__(self):
self.config = TelemetryConfig() self.config = TelemetryConfig()
self._customer_uuid: Optional[str] = None self._customer_uuid: Optional[str] = None
@ -188,23 +199,27 @@ class TelemetryCollector:
self._queue: "queue.Queue[TelemetryRecord]" = queue.Queue(maxsize=1000) self._queue: "queue.Queue[TelemetryRecord]" = queue.Queue(maxsize=1000)
# Load persistent data before starting worker so first events have UUID # Load persistent data before starting worker so first events have UUID
self._load_persistent_data() self._load_persistent_data()
self._worker: threading.Thread = threading.Thread(target=self._worker_loop, daemon=True) self._worker: threading.Thread = threading.Thread(
target=self._worker_loop, daemon=True)
self._worker.start() self._worker.start()
def _load_persistent_data(self): def _load_persistent_data(self):
"""Load UUID and milestones from disk""" """Load UUID and milestones from disk"""
# Load customer UUID # Load customer UUID
try: try:
if self.config.uuid_file.exists(): if self.config.uuid_file.exists():
self._customer_uuid = self.config.uuid_file.read_text(encoding="utf-8").strip() or str(uuid.uuid4()) self._customer_uuid = self.config.uuid_file.read_text(
encoding="utf-8").strip() or str(uuid.uuid4())
else: else:
self._customer_uuid = str(uuid.uuid4()) self._customer_uuid = str(uuid.uuid4())
try: try:
self.config.uuid_file.write_text(self._customer_uuid, encoding="utf-8") self.config.uuid_file.write_text(
self._customer_uuid, encoding="utf-8")
if os.name == "posix": if os.name == "posix":
os.chmod(self.config.uuid_file, 0o600) os.chmod(self.config.uuid_file, 0o600)
except OSError as e: except OSError as e:
logger.debug(f"Failed to persist customer UUID: {e}", exc_info=True) logger.debug(
f"Failed to persist customer UUID: {e}", exc_info=True)
except OSError as e: except OSError as e:
logger.debug(f"Failed to load customer UUID: {e}", exc_info=True) logger.debug(f"Failed to load customer UUID: {e}", exc_info=True)
self._customer_uuid = str(uuid.uuid4()) self._customer_uuid = str(uuid.uuid4())
@ -212,14 +227,15 @@ class TelemetryCollector:
# Load milestones (failure here must not affect UUID) # Load milestones (failure here must not affect UUID)
try: try:
if self.config.milestones_file.exists(): if self.config.milestones_file.exists():
content = self.config.milestones_file.read_text(encoding="utf-8") content = self.config.milestones_file.read_text(
encoding="utf-8")
self._milestones = json.loads(content) or {} self._milestones = json.loads(content) or {}
if not isinstance(self._milestones, dict): if not isinstance(self._milestones, dict):
self._milestones = {} self._milestones = {}
except (OSError, json.JSONDecodeError, ValueError) as e: except (OSError, json.JSONDecodeError, ValueError) as e:
logger.debug(f"Failed to load milestones: {e}", exc_info=True) logger.debug(f"Failed to load milestones: {e}", exc_info=True)
self._milestones = {} self._milestones = {}
def _save_milestones(self): def _save_milestones(self):
"""Save milestones to disk. Caller must hold self._lock.""" """Save milestones to disk. Caller must hold self._lock."""
try: try:
@ -229,7 +245,7 @@ class TelemetryCollector:
) )
except OSError as e: except OSError as e:
logger.warning(f"Failed to save milestones: {e}", exc_info=True) logger.warning(f"Failed to save milestones: {e}", exc_info=True)
def record_milestone(self, milestone: MilestoneType, data: Optional[Dict[str, Any]] = None) -> bool: def record_milestone(self, milestone: MilestoneType, data: Optional[Dict[str, Any]] = None) -> bool:
"""Record a milestone event, returns True if this is the first occurrence""" """Record a milestone event, returns True if this is the first occurrence"""
if not self.config.enabled: if not self.config.enabled:
@ -244,26 +260,26 @@ class TelemetryCollector:
} }
self._milestones[milestone_key] = milestone_data self._milestones[milestone_key] = milestone_data
self._save_milestones() self._save_milestones()
# Also send as telemetry record # Also send as telemetry record
self.record( self.record(
record_type=RecordType.USAGE, record_type=RecordType.USAGE,
data={"milestone": milestone_key, **(data or {})}, data={"milestone": milestone_key, **(data or {})},
milestone=milestone milestone=milestone
) )
return True return True
def record(self, def record(self,
record_type: RecordType, record_type: RecordType,
data: Dict[str, Any], data: Dict[str, Any],
milestone: Optional[MilestoneType] = None): milestone: Optional[MilestoneType] = None):
"""Record a telemetry event (async, non-blocking)""" """Record a telemetry event (async, non-blocking)"""
if not self.config.enabled: if not self.config.enabled:
return return
# Allow fallback sender when httpx is unavailable (no early return) # Allow fallback sender when httpx is unavailable (no early return)
record = TelemetryRecord( record = TelemetryRecord(
record_type=record_type, record_type=record_type,
timestamp=time.time(), timestamp=time.time(),
@ -276,7 +292,8 @@ class TelemetryCollector:
try: try:
self._queue.put_nowait(record) self._queue.put_nowait(record)
except queue.Full: except queue.Full:
logger.debug("Telemetry queue full; dropping %s", record.record_type) logger.debug("Telemetry queue full; dropping %s",
record.record_type)
def _worker_loop(self): def _worker_loop(self):
"""Background worker that serializes telemetry sends.""" """Background worker that serializes telemetry sends."""
@ -290,7 +307,7 @@ class TelemetryCollector:
finally: finally:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self._queue.task_done() self._queue.task_done()
def _send_telemetry(self, record: TelemetryRecord): def _send_telemetry(self, record: TelemetryRecord):
"""Send telemetry data to endpoint""" """Send telemetry data to endpoint"""
try: try:
@ -323,17 +340,20 @@ class TelemetryCollector:
if httpx: if httpx:
with httpx.Client(timeout=self.config.timeout) as client: with httpx.Client(timeout=self.config.timeout) as client:
# Re-validate endpoint at send time to handle dynamic changes # Re-validate endpoint at send time to handle dynamic changes
endpoint = self.config._validated_endpoint(self.config.endpoint, self.config.default_endpoint) endpoint = self.config._validated_endpoint(
self.config.endpoint, self.config.default_endpoint)
response = client.post(endpoint, json=payload) response = client.post(endpoint, json=payload)
if 200 <= response.status_code < 300: if 200 <= response.status_code < 300:
logger.debug(f"Telemetry sent: {record.record_type}") logger.debug(f"Telemetry sent: {record.record_type}")
else: else:
logger.warning(f"Telemetry failed: HTTP {response.status_code}") logger.warning(
f"Telemetry failed: HTTP {response.status_code}")
else: else:
import urllib.request import urllib.request
import urllib.error import urllib.error
data_bytes = json.dumps(payload).encode("utf-8") data_bytes = json.dumps(payload).encode("utf-8")
endpoint = self.config._validated_endpoint(self.config.endpoint, self.config.default_endpoint) endpoint = self.config._validated_endpoint(
self.config.endpoint, self.config.default_endpoint)
req = urllib.request.Request( req = urllib.request.Request(
endpoint, endpoint,
data=data_bytes, data=data_bytes,
@ -343,9 +363,11 @@ class TelemetryCollector:
try: try:
with urllib.request.urlopen(req, timeout=self.config.timeout) as resp: with urllib.request.urlopen(req, timeout=self.config.timeout) as resp:
if 200 <= resp.getcode() < 300: if 200 <= resp.getcode() < 300:
logger.debug(f"Telemetry sent (urllib): {record.record_type}") logger.debug(
f"Telemetry sent (urllib): {record.record_type}")
else: else:
logger.warning(f"Telemetry failed (urllib): HTTP {resp.getcode()}") logger.warning(
f"Telemetry failed (urllib): HTTP {resp.getcode()}")
except urllib.error.URLError as ue: except urllib.error.URLError as ue:
logger.warning(f"Telemetry send failed (urllib): {ue}") logger.warning(f"Telemetry send failed (urllib): {ue}")
@ -357,6 +379,7 @@ class TelemetryCollector:
# Global telemetry instance # Global telemetry instance
_telemetry_collector: Optional[TelemetryCollector] = None _telemetry_collector: Optional[TelemetryCollector] = None
def get_telemetry() -> TelemetryCollector: def get_telemetry() -> TelemetryCollector:
"""Get the global telemetry collector instance""" """Get the global telemetry collector instance"""
global _telemetry_collector global _telemetry_collector
@ -364,16 +387,19 @@ def get_telemetry() -> TelemetryCollector:
_telemetry_collector = TelemetryCollector() _telemetry_collector = TelemetryCollector()
return _telemetry_collector return _telemetry_collector
def record_telemetry(record_type: RecordType,
data: Dict[str, Any], def record_telemetry(record_type: RecordType,
milestone: Optional[MilestoneType] = None): data: Dict[str, Any],
milestone: Optional[MilestoneType] = None):
"""Convenience function to record telemetry""" """Convenience function to record telemetry"""
get_telemetry().record(record_type, data, milestone) get_telemetry().record(record_type, data, milestone)
def record_milestone(milestone: MilestoneType, data: Optional[Dict[str, Any]] = None) -> bool: def record_milestone(milestone: MilestoneType, data: Optional[Dict[str, Any]] = None) -> bool:
"""Convenience function to record a milestone""" """Convenience function to record a milestone"""
return get_telemetry().record_milestone(milestone, data) return get_telemetry().record_milestone(milestone, data)
def record_tool_usage(tool_name: str, success: bool, duration_ms: float, error: Optional[str] = None, sub_action: Optional[str] = None): def record_tool_usage(tool_name: str, success: bool, duration_ms: float, error: Optional[str] = None, sub_action: Optional[str] = None):
"""Record tool usage telemetry """Record tool usage telemetry
@ -396,36 +422,39 @@ def record_tool_usage(tool_name: str, success: bool, duration_ms: float, error:
except Exception: except Exception:
# Ensure telemetry is never disruptive # Ensure telemetry is never disruptive
data["sub_action"] = "unknown" data["sub_action"] = "unknown"
if error: if error:
data["error"] = str(error)[:200] # Limit error message length data["error"] = str(error)[:200] # Limit error message length
record_telemetry(RecordType.TOOL_EXECUTION, data) record_telemetry(RecordType.TOOL_EXECUTION, data)
def record_latency(operation: str, duration_ms: float, metadata: Optional[Dict[str, Any]] = None): def record_latency(operation: str, duration_ms: float, metadata: Optional[Dict[str, Any]] = None):
"""Record latency telemetry""" """Record latency telemetry"""
data = { data = {
"operation": operation, "operation": operation,
"duration_ms": round(duration_ms, 2) "duration_ms": round(duration_ms, 2)
} }
if metadata: if metadata:
data.update(metadata) data.update(metadata)
record_telemetry(RecordType.LATENCY, data) record_telemetry(RecordType.LATENCY, data)
def record_failure(component: str, error: str, metadata: Optional[Dict[str, Any]] = None): def record_failure(component: str, error: str, metadata: Optional[Dict[str, Any]] = None):
"""Record failure telemetry""" """Record failure telemetry"""
data = { data = {
"component": component, "component": component,
"error": str(error)[:500] # Limit error message length "error": str(error)[:500] # Limit error message length
} }
if metadata: if metadata:
data.update(metadata) data.update(metadata)
record_telemetry(RecordType.FAILURE, data) record_telemetry(RecordType.FAILURE, data)
def is_telemetry_enabled() -> bool: def is_telemetry_enabled() -> bool:
"""Check if telemetry is enabled""" """Check if telemetry is enabled"""
return get_telemetry().config.enabled return get_telemetry().config.enabled

View File

@ -3,15 +3,17 @@ Telemetry decorator for Unity MCP tools
""" """
import functools import functools
import time
import inspect import inspect
import logging import logging
import time
from typing import Callable, Any from typing import Callable, Any
from telemetry import record_tool_usage, record_milestone, MilestoneType from telemetry import record_tool_usage, record_milestone, MilestoneType
_log = logging.getLogger("unity-mcp-telemetry") _log = logging.getLogger("unity-mcp-telemetry")
_decorator_log_count = 0 _decorator_log_count = 0
def telemetry_tool(tool_name: str): def telemetry_tool(tool_name: str):
"""Decorator to add telemetry tracking to MCP tools""" """Decorator to add telemetry tracking to MCP tools"""
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@ -41,7 +43,8 @@ def telemetry_tool(tool_name: str):
if tool_name == "manage_script" and action_val == "create": if tool_name == "manage_script" and action_val == "create":
record_milestone(MilestoneType.FIRST_SCRIPT_CREATION) record_milestone(MilestoneType.FIRST_SCRIPT_CREATION)
elif tool_name.startswith("manage_scene"): elif tool_name.startswith("manage_scene"):
record_milestone(MilestoneType.FIRST_SCENE_MODIFICATION) record_milestone(
MilestoneType.FIRST_SCENE_MODIFICATION)
record_milestone(MilestoneType.FIRST_TOOL_USAGE) record_milestone(MilestoneType.FIRST_TOOL_USAGE)
except Exception: except Exception:
_log.debug("milestone emit failed", exc_info=True) _log.debug("milestone emit failed", exc_info=True)
@ -52,7 +55,8 @@ def telemetry_tool(tool_name: str):
finally: finally:
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
try: try:
record_tool_usage(tool_name, success, duration_ms, error, sub_action=sub_action) record_tool_usage(tool_name, success,
duration_ms, error, sub_action=sub_action)
except Exception: except Exception:
_log.debug("record_tool_usage failed", exc_info=True) _log.debug("record_tool_usage failed", exc_info=True)
@ -82,7 +86,8 @@ def telemetry_tool(tool_name: str):
if tool_name == "manage_script" and action_val == "create": if tool_name == "manage_script" and action_val == "create":
record_milestone(MilestoneType.FIRST_SCRIPT_CREATION) record_milestone(MilestoneType.FIRST_SCRIPT_CREATION)
elif tool_name.startswith("manage_scene"): elif tool_name.startswith("manage_scene"):
record_milestone(MilestoneType.FIRST_SCENE_MODIFICATION) record_milestone(
MilestoneType.FIRST_SCENE_MODIFICATION)
record_milestone(MilestoneType.FIRST_TOOL_USAGE) record_milestone(MilestoneType.FIRST_TOOL_USAGE)
except Exception: except Exception:
_log.debug("milestone emit failed", exc_info=True) _log.debug("milestone emit failed", exc_info=True)
@ -93,9 +98,10 @@ def telemetry_tool(tool_name: str):
finally: finally:
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
try: try:
record_tool_usage(tool_name, success, duration_ms, error, sub_action=sub_action) record_tool_usage(tool_name, success,
duration_ms, error, sub_action=sub_action)
except Exception: except Exception:
_log.debug("record_tool_usage failed", exc_info=True) _log.debug("record_tool_usage failed", exc_info=True)
return _async_wrapper if inspect.iscoroutinefunction(func) else _sync_wrapper return _async_wrapper if inspect.iscoroutinefunction(func) else _sync_wrapper
return decorator return decorator

View File

@ -5,30 +5,30 @@ Run this to verify telemetry is working correctly
""" """
import os import os
import time
import sys
from pathlib import Path from pathlib import Path
import sys
# Add src to Python path for imports # Add src to Python path for imports
sys.path.insert(0, str(Path(__file__).parent)) sys.path.insert(0, str(Path(__file__).parent))
def test_telemetry_basic(): def test_telemetry_basic():
"""Test basic telemetry functionality""" """Test basic telemetry functionality"""
# Avoid stdout noise in tests # Avoid stdout noise in tests
try: try:
from telemetry import ( from telemetry import (
get_telemetry, record_telemetry, record_milestone, get_telemetry, record_telemetry, record_milestone,
RecordType, MilestoneType, is_telemetry_enabled RecordType, MilestoneType, is_telemetry_enabled
) )
pass pass
except ImportError as e: except ImportError as e:
# Silent failure path for tests # Silent failure path for tests
return False return False
# Test telemetry enabled status # Test telemetry enabled status
_ = is_telemetry_enabled() _ = is_telemetry_enabled()
# Test basic record # Test basic record
try: try:
record_telemetry(RecordType.VERSION, { record_telemetry(RecordType.VERSION, {
@ -39,7 +39,7 @@ def test_telemetry_basic():
except Exception as e: except Exception as e:
# Silent failure path for tests # Silent failure path for tests
return False return False
# Test milestone recording # Test milestone recording
try: try:
is_first = record_milestone(MilestoneType.FIRST_STARTUP, { is_first = record_milestone(MilestoneType.FIRST_STARTUP, {
@ -49,7 +49,7 @@ def test_telemetry_basic():
except Exception as e: except Exception as e:
# Silent failure path for tests # Silent failure path for tests
return False return False
# Test telemetry collector # Test telemetry collector
try: try:
collector = get_telemetry() collector = get_telemetry()
@ -57,79 +57,83 @@ def test_telemetry_basic():
except Exception as e: except Exception as e:
# Silent failure path for tests # Silent failure path for tests
return False return False
return True return True
def test_telemetry_disabled(): def test_telemetry_disabled():
"""Test telemetry with disabled state""" """Test telemetry with disabled state"""
# Silent for tests # Silent for tests
# Set environment variable to disable telemetry # Set environment variable to disable telemetry
os.environ["DISABLE_TELEMETRY"] = "true" os.environ["DISABLE_TELEMETRY"] = "true"
# Re-import to get fresh config # Re-import to get fresh config
import importlib import importlib
import telemetry import telemetry
importlib.reload(telemetry) importlib.reload(telemetry)
from telemetry import is_telemetry_enabled, record_telemetry, RecordType from telemetry import is_telemetry_enabled, record_telemetry, RecordType
_ = is_telemetry_enabled() _ = is_telemetry_enabled()
if not is_telemetry_enabled(): if not is_telemetry_enabled():
pass pass
# Test that records are ignored when disabled # Test that records are ignored when disabled
record_telemetry(RecordType.USAGE, {"test": "should_be_ignored"}) record_telemetry(RecordType.USAGE, {"test": "should_be_ignored"})
pass pass
return True return True
else: else:
pass pass
return False return False
def test_data_storage(): def test_data_storage():
"""Test data storage functionality""" """Test data storage functionality"""
# Silent for tests # Silent for tests
try: try:
from telemetry import get_telemetry from telemetry import get_telemetry
collector = get_telemetry() collector = get_telemetry()
data_dir = collector.config.data_dir data_dir = collector.config.data_dir
_ = (data_dir, collector.config.uuid_file, collector.config.milestones_file) _ = (data_dir, collector.config.uuid_file,
collector.config.milestones_file)
# Check if files exist # Check if files exist
if collector.config.uuid_file.exists(): if collector.config.uuid_file.exists():
pass pass
else: else:
pass pass
if collector.config.milestones_file.exists(): if collector.config.milestones_file.exists():
pass pass
else: else:
pass pass
return True return True
except Exception as e: except Exception as e:
# Silent failure path for tests # Silent failure path for tests
return False return False
def main(): def main():
"""Run all telemetry tests""" """Run all telemetry tests"""
# Silent runner for CI # Silent runner for CI
tests = [ tests = [
test_telemetry_basic, test_telemetry_basic,
test_data_storage, test_data_storage,
test_telemetry_disabled, test_telemetry_disabled,
] ]
passed = 0 passed = 0
failed = 0 failed = 0
for test in tests: for test in tests:
try: try:
if test(): if test():
@ -141,9 +145,9 @@ def main():
except Exception as e: except Exception as e:
failed += 1 failed += 1
pass pass
_ = (passed, failed) _ = (passed, failed)
if failed == 0: if failed == 0:
pass pass
return True return True
@ -151,6 +155,7 @@ def main():
pass pass
return False return False
if __name__ == "__main__": if __name__ == "__main__":
success = main() success = main()
sys.exit(0 if success else 1) sys.exit(0 if success else 1)

View File

@ -1,4 +1,7 @@
import logging import logging
from mcp.server.fastmcp import FastMCP
from .manage_script_edits import register_manage_script_edits_tools from .manage_script_edits import register_manage_script_edits_tools
from .manage_script import register_manage_script_tools from .manage_script import register_manage_script_tools
from .manage_scene import register_manage_scene_tools from .manage_scene import register_manage_scene_tools
@ -13,7 +16,8 @@ from .resource_tools import register_resource_tools
logger = logging.getLogger("mcp-for-unity-server") logger = logging.getLogger("mcp-for-unity-server")
def register_all_tools(mcp):
def register_all_tools(mcp: FastMCP):
"""Register all refactored tools with the MCP server.""" """Register all refactored tools with the MCP server."""
# Prefer the surgical edits tool so LLMs discover it first # Prefer the surgical edits tool so LLMs discover it first
logger.info("Registering MCP for Unity Server refactored tools...") logger.info("Registering MCP for Unity Server refactored tools...")

View File

@ -1,58 +1,45 @@
""" """
Defines the manage_asset tool for interacting with Unity assets. Defines the manage_asset tool for interacting with Unity assets.
""" """
import asyncio # Added: Import asyncio for running sync code in async import asyncio
from typing import Dict, Any from typing import Annotated, Any, Literal
from mcp.server.fastmcp import FastMCP, Context
# from ..unity_connection import get_unity_connection # Original line that caused error
from unity_connection import get_unity_connection, async_send_command_with_retry # Use centralized retry helper
from config import config
import time
from mcp.server.fastmcp import FastMCP, Context
from unity_connection import async_send_command_with_retry
from telemetry_decorator import telemetry_tool from telemetry_decorator import telemetry_tool
def register_manage_asset_tools(mcp: FastMCP): def register_manage_asset_tools(mcp: FastMCP):
"""Registers the manage_asset tool with the MCP server.""" """Registers the manage_asset tool with the MCP server."""
@mcp.tool() @mcp.tool(name="manage_asset", description="Performs asset operations (import, create, modify, delete, etc.) in Unity.")
@telemetry_tool("manage_asset") @telemetry_tool("manage_asset")
async def manage_asset( async def manage_asset(
ctx: Any, ctx: Context,
action: str, action: Annotated[Literal["import", "create", "modify", "delete", "duplicate", "move", "rename", "search", "get_info", "create_folder", "get_components"], "Perform CRUD operations on assets."],
path: str, path: Annotated[str, "Asset path (e.g., 'Materials/MyMaterial.mat') or search scope."],
asset_type: str = None, asset_type: Annotated[str,
properties: Dict[str, Any] = None, "Asset type (e.g., 'Material', 'Folder') - required for 'create'."] | None = None,
destination: str = None, properties: Annotated[dict[str, Any],
generate_preview: bool = False, "Dictionary of properties for 'create'/'modify'."] | None = None,
search_pattern: str = None, destination: Annotated[str,
filter_type: str = None, "Target path for 'duplicate'/'move'."] | None = None,
filter_date_after: str = None, generate_preview: Annotated[bool,
page_size: Any = None, "Generate a preview/thumbnail for the asset when supported."] = False,
page_number: Any = None search_pattern: Annotated[str,
) -> Dict[str, Any]: "Search pattern (e.g., '*.prefab')."] | None = None,
"""Performs asset operations (import, create, modify, delete, etc.) in Unity. filter_type: Annotated[str, "Filter type for search"] | None = None,
filter_date_after: Annotated[str,
Args: "Date after which to filter"] | None = None,
ctx: The MCP context. page_size: Annotated[int, "Page size for pagination"] | None = None,
action: Operation to perform (e.g., 'import', 'create', 'modify', 'delete', 'duplicate', 'move', 'rename', 'search', 'get_info', 'create_folder', 'get_components'). page_number: Annotated[int, "Page number for pagination"] | None = None
path: Asset path (e.g., "Materials/MyMaterial.mat") or search scope. ) -> dict[str, Any]:
asset_type: Asset type (e.g., 'Material', 'Folder') - required for 'create'. ctx.info(f"Processing manage_asset: {action}")
properties: Dictionary of properties for 'create'/'modify'.
example properties for Material: {"color": [1, 0, 0, 1], "shader": "Standard"}.
example properties for Texture: {"width": 1024, "height": 1024, "format": "RGBA32"}.
example properties for PhysicsMaterial: {"bounciness": 1.0, "staticFriction": 0.5, "dynamicFriction": 0.5}.
destination: Target path for 'duplicate'/'move'.
search_pattern: Search pattern (e.g., '*.prefab').
filter_*: Filters for search (type, date).
page_*: Pagination for search.
Returns:
A dictionary with operation results ('success', 'data', 'error').
"""
# Ensure properties is a dict if None # Ensure properties is a dict if None
if properties is None: if properties is None:
properties = {} properties = {}
# Coerce numeric inputs defensively # Coerce numeric inputs defensively
def _coerce_int(value, default=None): def _coerce_int(value, default=None):
if value is None: if value is None:
@ -86,15 +73,13 @@ def register_manage_asset_tools(mcp: FastMCP):
"pageSize": page_size, "pageSize": page_size,
"pageNumber": page_number "pageNumber": page_number
} }
# Remove None values to avoid sending unnecessary nulls # Remove None values to avoid sending unnecessary nulls
params_dict = {k: v for k, v in params_dict.items() if v is not None} params_dict = {k: v for k, v in params_dict.items() if v is not None}
# Get the current asyncio event loop # Get the current asyncio event loop
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
# Get the Unity connection instance
connection = get_unity_connection()
# Use centralized async retry helper to avoid blocking the event loop # Use centralized async retry helper to avoid blocking the event loop
result = await async_send_command_with_retry("manage_asset", params_dict, loop=loop) result = await async_send_command_with_retry("manage_asset", params_dict, loop=loop)
# Return the result obtained from Unity # Return the result obtained from Unity

View File

@ -1,37 +1,31 @@
from mcp.server.fastmcp import FastMCP, Context from typing import Annotated, Any, Literal
import time
from typing import Dict, Any
from unity_connection import get_unity_connection, send_command_with_retry
from config import config
from mcp.server.fastmcp import FastMCP, Context
from telemetry_decorator import telemetry_tool from telemetry_decorator import telemetry_tool
from telemetry import is_telemetry_enabled, record_tool_usage from telemetry import is_telemetry_enabled, record_tool_usage
from unity_connection import send_command_with_retry
def register_manage_editor_tools(mcp: FastMCP): def register_manage_editor_tools(mcp: FastMCP):
"""Register all editor management tools with the MCP server.""" """Register all editor management tools with the MCP server."""
@mcp.tool(description=( @mcp.tool(name="manage_editor", description="Controls and queries the Unity editor's state and settings")
"Controls and queries the Unity editor's state and settings.\n\n"
"Args:\n"
"- ctx: Context object (required)\n"
"- action: Operation (e.g., 'play', 'pause', 'get_state', 'set_active_tool', 'add_tag')\n"
"- wait_for_completion: Optional. If True, waits for certain actions\n"
"- tool_name: Tool name for specific actions\n"
"- tag_name: Tag name for specific actions\n"
"- layer_name: Layer name for specific actions\n\n"
"Returns:\n"
"Dictionary with operation results ('success', 'message', 'data')."
))
@telemetry_tool("manage_editor") @telemetry_tool("manage_editor")
def manage_editor( def manage_editor(
ctx: Context, ctx: Context,
action: str, action: Annotated[Literal["telemetry_status", "telemetry_ping", "play", "pause", "stop", "get_state", "get_project_root", "get_windows",
wait_for_completion: bool = None, "get_active_tool", "get_selection", "get_prefab_stage", "set_active_tool", "add_tag", "remove_tag", "get_tags", "add_layer", "remove_layer", "get_layers"], "Get and update the Unity Editor state."],
# --- Parameters for specific actions --- wait_for_completion: Annotated[bool,
tool_name: str = None, "Optional. If True, waits for certain actions"] | None = None,
tag_name: str = None, tool_name: Annotated[str,
layer_name: str = None, "Tool name when setting active tool"] | None = None,
) -> Dict[str, Any]: tag_name: Annotated[str,
"Tag name when adding and removing tags"] | None = None,
layer_name: Annotated[str,
"Layer name when adding and removing layers"] | None = None,
) -> dict[str, Any]:
ctx.info(f"Processing manage_editor: {action}")
try: try:
# Diagnostics: quick telemetry checks # Diagnostics: quick telemetry checks
if action == "telemetry_status": if action == "telemetry_status":
@ -44,16 +38,16 @@ def register_manage_editor_tools(mcp: FastMCP):
params = { params = {
"action": action, "action": action,
"waitForCompletion": wait_for_completion, "waitForCompletion": wait_for_completion,
"toolName": tool_name, # Corrected parameter name to match C# "toolName": tool_name, # Corrected parameter name to match C#
"tagName": tag_name, # Pass tag name "tagName": tag_name, # Pass tag name
"layerName": layer_name, # Pass layer name "layerName": layer_name, # Pass layer name
# Add other parameters based on the action being performed # Add other parameters based on the action being performed
# "width": width, # "width": width,
# "height": height, # "height": height,
# etc. # etc.
} }
params = {k: v for k, v in params.items() if v is not None} params = {k: v for k, v in params.items() if v is not None}
# Send command using centralized retry helper # Send command using centralized retry helper
response = send_command_with_retry("manage_editor", params) response = send_command_with_retry("manage_editor", params)
@ -63,4 +57,4 @@ def register_manage_editor_tools(mcp: FastMCP):
return response if isinstance(response, dict) else {"success": False, "message": str(response)} return response if isinstance(response, dict) else {"success": False, "message": str(response)}
except Exception as e: except Exception as e:
return {"success": False, "message": f"Python error managing editor: {str(e)}"} return {"success": False, "message": f"Python error managing editor: {str(e)}"}

View File

@ -1,87 +1,74 @@
from mcp.server.fastmcp import FastMCP, Context from typing import Annotated, Any, Literal
from typing import Dict, Any, List
from unity_connection import get_unity_connection, send_command_with_retry
from config import config
import time
from mcp.server.fastmcp import FastMCP, Context
from telemetry_decorator import telemetry_tool from telemetry_decorator import telemetry_tool
from unity_connection import send_command_with_retry
def register_manage_gameobject_tools(mcp: FastMCP): def register_manage_gameobject_tools(mcp: FastMCP):
"""Register all GameObject management tools with the MCP server.""" """Register all GameObject management tools with the MCP server."""
@mcp.tool() @mcp.tool(name="manage_gameobject", description="Manage GameObjects. Note: for 'get_components', the `data` field contains a dictionary of component names and their serialized properties.")
@telemetry_tool("manage_gameobject") @telemetry_tool("manage_gameobject")
def manage_gameobject( def manage_gameobject(
ctx: Any, ctx: Context,
action: str, action: Annotated[Literal["create", "modify", "delete", "find", "add_component", "remove_component", "set_component_property", "get_components"], "Perform CRUD operations on GameObjects and components."],
target: str = None, # GameObject identifier by name or path target: Annotated[str,
search_method: str = None, "GameObject identifier by name or path for modify/delete/component actions"] | None = None,
# --- Combined Parameters for Create/Modify --- search_method: Annotated[str,
name: str = None, # Used for both 'create' (new object name) and 'modify' (rename) "How to find objects ('by_name', 'by_id', 'by_path', etc.). Used with 'find' and some 'target' lookups."] | None = None,
tag: str = None, # Used for both 'create' (initial tag) and 'modify' (change tag) name: Annotated[str,
parent: str = None, # Used for both 'create' (initial parent) and 'modify' (change parent) "GameObject name - used for both 'create' (initial name) and 'modify' (rename)"] | None = None,
position: List[float] = None, tag: Annotated[str,
rotation: List[float] = None, "Tag name - used for both 'create' (initial tag) and 'modify' (change tag)"] | None = None,
scale: List[float] = None, parent: Annotated[str,
components_to_add: List[str] = None, # List of component names to add "Parent GameObject reference - used for both 'create' (initial parent) and 'modify' (change parent)"] | None = None,
primitive_type: str = None, position: Annotated[list[float],
save_as_prefab: bool = False, "Position - used for both 'create' (initial position) and 'modify' (change position)"] | None = None,
prefab_path: str = None, rotation: Annotated[list[float],
prefab_folder: str = "Assets/Prefabs", "Rotation - used for both 'create' (initial rotation) and 'modify' (change rotation)"] | None = None,
scale: Annotated[list[float],
"Scale - used for both 'create' (initial scale) and 'modify' (change scale)"] | None = None,
components_to_add: Annotated[list[str],
"List of component names to add"] | None = None,
primitive_type: Annotated[str,
"Primitive type for 'create' action"] | None = None,
save_as_prefab: Annotated[bool,
"If True, saves the created GameObject as a prefab"] | None = None,
prefab_path: Annotated[str, "Path for prefab creation"] | None = None,
prefab_folder: Annotated[str,
"Folder for prefab creation"] | None = None,
# --- Parameters for 'modify' --- # --- Parameters for 'modify' ---
set_active: bool = None, set_active: Annotated[bool,
layer: str = None, # Layer name "If True, sets the GameObject active"] | None = None,
components_to_remove: List[str] = None, layer: Annotated[str, "Layer name"] | None = None,
component_properties: Dict[str, Dict[str, Any]] = None, components_to_remove: Annotated[list[str],
"List of component names to remove"] | None = None,
component_properties: Annotated[dict[str, dict[str, Any]],
"""Dictionary of component names to their properties to set. For example:
`{"MyScript": {"otherObject": {"find": "Player", "method": "by_name"}}}` assigns GameObject
`{"MyScript": {"playerHealth": {"find": "Player", "component": "HealthComponent"}}}` assigns Component
Example set nested property:
- Access shared material: `{"MeshRenderer": {"sharedMaterial.color": [1, 0, 0, 1]}}`"""] | None = None,
# --- Parameters for 'find' --- # --- Parameters for 'find' ---
search_term: str = None, search_term: Annotated[str,
find_all: bool = False, "Search term for 'find' action"] | None = None,
search_in_children: bool = False, find_all: Annotated[bool,
search_inactive: bool = False, "If True, finds all GameObjects matching the search term"] | None = None,
search_in_children: Annotated[bool,
"If True, searches in children of the GameObject"] | None = None,
search_inactive: Annotated[bool,
"If True, searches inactive GameObjects"] | None = None,
# -- Component Management Arguments -- # -- Component Management Arguments --
component_name: str = None, component_name: Annotated[str,
includeNonPublicSerialized: bool = None, # Controls serialization of private [SerializeField] fields "Component name for 'add_component' and 'remove_component' actions"] | None = None,
) -> Dict[str, Any]: # Controls whether serialization of private [SerializeField] fields is included
"""Manages GameObjects: create, modify, delete, find, and component operations. includeNonPublicSerialized: Annotated[bool,
"Controls whether serialization of private [SerializeField] fields is included"] | None = None,
Args: ) -> dict[str, Any]:
action: Operation (e.g., 'create', 'modify', 'find', 'add_component', 'remove_component', 'set_component_property', 'get_components'). ctx.info(f"Processing manage_gameobject: {action}")
target: GameObject identifier (name or path string) for modify/delete/component actions.
search_method: How to find objects ('by_name', 'by_id', 'by_path', etc.). Used with 'find' and some 'target' lookups.
name: GameObject name - used for both 'create' (initial name) and 'modify' (rename).
tag: Tag name - used for both 'create' (initial tag) and 'modify' (change tag).
parent: Parent GameObject reference - used for both 'create' (initial parent) and 'modify' (change parent).
layer: Layer name - used for both 'create' (initial layer) and 'modify' (change layer).
component_properties: Dict mapping Component names to their properties to set.
Example: {"Rigidbody": {"mass": 10.0, "useGravity": True}},
To set references:
- Use asset path string for Prefabs/Materials, e.g., {"MeshRenderer": {"material": "Assets/Materials/MyMat.mat"}}
- Use a dict for scene objects/components, e.g.:
{"MyScript": {"otherObject": {"find": "Player", "method": "by_name"}}} (assigns GameObject)
{"MyScript": {"playerHealth": {"find": "Player", "component": "HealthComponent"}}} (assigns Component)
Example set nested property:
- Access shared material: {"MeshRenderer": {"sharedMaterial.color": [1, 0, 0, 1]}}
components_to_add: List of component names to add.
Action-specific arguments (e.g., position, rotation, scale for create/modify;
component_name for component actions;
search_term, find_all for 'find').
includeNonPublicSerialized: If True, includes private fields marked [SerializeField] in component data.
Action-specific details:
- For 'get_components':
Required: target, search_method
Optional: includeNonPublicSerialized (defaults to True)
Returns all components on the target GameObject with their serialized data.
The search_method parameter determines how to find the target ('by_name', 'by_id', 'by_path').
Returns:
Dictionary with operation results ('success', 'message', 'data').
For 'get_components', the 'data' field contains a dictionary of component names and their serialized properties.
"""
try: try:
# --- Early check for attempting to modify a prefab asset ---
# ----------------------------------------------------------
# Prepare parameters, removing None values # Prepare parameters, removing None values
params = { params = {
"action": action, "action": action,
@ -110,9 +97,10 @@ def register_manage_gameobject_tools(mcp: FastMCP):
"includeNonPublicSerialized": includeNonPublicSerialized "includeNonPublicSerialized": includeNonPublicSerialized
} }
params = {k: v for k, v in params.items() if v is not None} params = {k: v for k, v in params.items() if v is not None}
# --- Handle Prefab Path Logic --- # --- Handle Prefab Path Logic ---
if action == "create" and params.get("saveAsPrefab"): # Check if 'saveAsPrefab' is explicitly True in params # Check if 'saveAsPrefab' is explicitly True in params
if action == "create" and params.get("saveAsPrefab"):
if "prefabPath" not in params: if "prefabPath" not in params:
if "name" not in params or not params["name"]: if "name" not in params or not params["name"]:
return {"success": False, "message": "Cannot create default prefab path: 'name' parameter is missing."} return {"success": False, "message": "Cannot create default prefab path: 'name' parameter is missing."}
@ -124,9 +112,9 @@ def register_manage_gameobject_tools(mcp: FastMCP):
return {"success": False, "message": f"Invalid prefab_path: '{params['prefabPath']}' must end with .prefab"} return {"success": False, "message": f"Invalid prefab_path: '{params['prefabPath']}' must end with .prefab"}
# Ensure prefabFolder itself isn't sent if prefabPath was constructed or provided # Ensure prefabFolder itself isn't sent if prefabPath was constructed or provided
# The C# side only needs the final prefabPath # The C# side only needs the final prefabPath
params.pop("prefabFolder", None) params.pop("prefabFolder", None)
# -------------------------------- # --------------------------------
# Use centralized retry helper # Use centralized retry helper
response = send_command_with_retry("manage_gameobject", params) response = send_command_with_retry("manage_gameobject", params)
@ -137,4 +125,4 @@ def register_manage_gameobject_tools(mcp: FastMCP):
return response if isinstance(response, dict) else {"success": False, "message": str(response)} return response if isinstance(response, dict) else {"success": False, "message": str(response)}
except Exception as e: except Exception as e:
return {"success": False, "message": f"Python error managing GameObject: {str(e)}"} return {"success": False, "message": f"Python error managing GameObject: {str(e)}"}

View File

@ -7,24 +7,25 @@ from typing import Annotated, Any, Literal
from mcp.server.fastmcp import FastMCP, Context from mcp.server.fastmcp import FastMCP, Context
from telemetry_decorator import telemetry_tool from telemetry_decorator import telemetry_tool
from unity_connection import get_unity_connection, async_send_command_with_retry from unity_connection import async_send_command_with_retry
def register_manage_menu_item_tools(mcp: FastMCP): def register_manage_menu_item_tools(mcp: FastMCP):
"""Registers the manage_menu_item tool with the MCP server.""" """Registers the manage_menu_item tool with the MCP server."""
@mcp.tool(description="Manage Unity menu items (execute/list/exists). If you're not sure what menu item to use, use the 'list' action to find it before using 'execute'.") @mcp.tool(name="manage_menu_item", description="Manage Unity menu items (execute/list/exists). If you're not sure what menu item to use, use the 'list' action to find it before using 'execute'.")
@telemetry_tool("manage_menu_item") @telemetry_tool("manage_menu_item")
async def manage_menu_item( async def manage_menu_item(
ctx: Context, ctx: Context,
action: Annotated[Literal["execute", "list", "exists"], "One of 'execute', 'list', 'exists'"], action: Annotated[Literal["execute", "list", "exists"], "Read and execute Unity menu items."],
menu_path: Annotated[str | None, menu_path: Annotated[str,
"Menu path for 'execute' or 'exists' (e.g., 'File/Save Project')"] = None, "Menu path for 'execute' or 'exists' (e.g., 'File/Save Project')"] | None = None,
search: Annotated[str | None, search: Annotated[str,
"Optional filter string for 'list' (e.g., 'Save')"] = None, "Optional filter string for 'list' (e.g., 'Save')"] | None = None,
refresh: Annotated[bool | None, refresh: Annotated[bool,
"Optional flag to force refresh of the menu cache when listing"] = None, "Optional flag to force refresh of the menu cache when listing"] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
ctx.info(f"Processing manage_menu_item: {action}")
# Prepare parameters for the C# handler # Prepare parameters for the C# handler
params_dict: dict[str, Any] = { params_dict: dict[str, Any] = {
"action": action, "action": action,
@ -37,8 +38,6 @@ def register_manage_menu_item_tools(mcp: FastMCP):
# Get the current asyncio event loop # Get the current asyncio event loop
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
# Touch the connection to ensure availability (mirrors other tools' pattern)
_ = get_unity_connection()
# Use centralized async retry helper # Use centralized async retry helper
result = await async_send_command_with_retry("manage_menu_item", params_dict, loop=loop) result = await async_send_command_with_retry("manage_menu_item", params_dict, loop=loop)

View File

@ -1,14 +1,15 @@
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.fastmcp import FastMCP, Context
from telemetry_decorator import telemetry_tool from telemetry_decorator import telemetry_tool
from unity_connection import send_command_with_retry from unity_connection import send_command_with_retry
def register_manage_prefabs_tools(mcp: FastMCP) -> None: def register_manage_prefabs_tools(mcp: FastMCP) -> None:
"""Register prefab management tools with the MCP server.""" """Register prefab management tools with the MCP server."""
@mcp.tool(description="Bridge for prefab management commands (stage control and creation).") @mcp.tool(name="manage_prefabs", description="Bridge for prefab management commands (stage control and creation).")
@telemetry_tool("manage_prefabs") @telemetry_tool("manage_prefabs")
def manage_prefabs( def manage_prefabs(
ctx: Context, ctx: Context,
@ -17,20 +18,21 @@ def register_manage_prefabs_tools(mcp: FastMCP) -> None:
"close_stage", "close_stage",
"save_open_stage", "save_open_stage",
"create_from_gameobject", "create_from_gameobject",
], "One of open_stage, close_stage, save_open_stage, create_from_gameobject"], ], "Manage prefabs (stage control and creation)."],
prefab_path: Annotated[str | None, prefab_path: Annotated[str,
"Prefab asset path relative to Assets e.g. Assets/Prefabs/favorite.prefab"] = None, "Prefab asset path relative to Assets e.g. Assets/Prefabs/favorite.prefab"] | None = None,
mode: Annotated[str | None, mode: Annotated[str,
"Optional prefab stage mode (only 'InIsolation' is currently supported)"] = None, "Optional prefab stage mode (only 'InIsolation' is currently supported)"] | None = None,
save_before_close: Annotated[bool | None, save_before_close: Annotated[bool,
"When true, `close_stage` will save the prefab before exiting the stage."] = None, "When true, `close_stage` will save the prefab before exiting the stage."] | None = None,
target: Annotated[str | None, target: Annotated[str,
"Scene GameObject name required for create_from_gameobject"] = None, "Scene GameObject name required for create_from_gameobject"] | None = None,
allow_overwrite: Annotated[bool | None, allow_overwrite: Annotated[bool,
"Allow replacing an existing prefab at the same path"] = None, "Allow replacing an existing prefab at the same path"] | None = None,
search_inactive: Annotated[bool | None, search_inactive: Annotated[bool,
"Include inactive objects when resolving the target name"] = None, "Include inactive objects when resolving the target name"] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
ctx.info(f"Processing manage_prefabs: {action}")
try: try:
params: dict[str, Any] = {"action": action} params: dict[str, Any] = {"action": action}

View File

@ -1,35 +1,27 @@
from mcp.server.fastmcp import FastMCP, Context from typing import Annotated, Literal, Any
from typing import Dict, Any
from unity_connection import get_unity_connection, send_command_with_retry
from config import config
import time
from mcp.server.fastmcp import FastMCP, Context
from telemetry_decorator import telemetry_tool from telemetry_decorator import telemetry_tool
from unity_connection import send_command_with_retry
def register_manage_scene_tools(mcp: FastMCP): def register_manage_scene_tools(mcp: FastMCP):
"""Register all scene management tools with the MCP server.""" """Register all scene management tools with the MCP server."""
@mcp.tool() @mcp.tool(name="manage_scene", description="Manage Unity scenes")
@telemetry_tool("manage_scene") @telemetry_tool("manage_scene")
def manage_scene( def manage_scene(
ctx: Context, ctx: Context,
action: str, action: Annotated[Literal["create", "load", "save", "get_hierarchy", "get_active", "get_build_settings"], "Perform CRUD operations on Unity scenes."],
name: str = "", name: Annotated[str,
path: str = "", "Scene name. Not required get_active/get_build_settings"] | None = None,
build_index: Any = None, path: Annotated[str,
) -> Dict[str, Any]: "Asset path for scene operations (default: 'Assets/')"] | None = None,
"""Manages Unity scenes (load, save, create, get hierarchy, etc.). build_index: Annotated[int,
"Build index for load/build settings actions"] | None = None,
Args: ) -> dict[str, Any]:
action: Operation (e.g., 'load', 'save', 'create', 'get_hierarchy'). ctx.info(f"Processing manage_scene: {action}")
name: Scene name (no extension) for create/load/save.
path: Asset path for scene operations (default: "Assets/").
build_index: Build index for load/build settings actions.
# Add other action-specific args as needed (e.g., for hierarchy depth)
Returns:
Dictionary with results ('success', 'message', 'data').
"""
try: try:
# Coerce numeric inputs defensively # Coerce numeric inputs defensively
def _coerce_int(value, default=None): def _coerce_int(value, default=None):
@ -56,7 +48,7 @@ def register_manage_scene_tools(mcp: FastMCP):
params["path"] = path params["path"] = path
if coerced_build_index is not None: if coerced_build_index is not None:
params["buildIndex"] = coerced_build_index params["buildIndex"] = coerced_build_index
# Use centralized retry helper # Use centralized retry helper
response = send_command_with_retry("manage_scene", params) response = send_command_with_retry("manage_scene", params)
@ -66,4 +58,4 @@ def register_manage_scene_tools(mcp: FastMCP):
return response if isinstance(response, dict) else {"success": False, "message": str(response)} return response if isinstance(response, dict) else {"success": False, "message": str(response)}
except Exception as e: except Exception as e:
return {"success": False, "message": f"Python error managing scene: {str(e)}"} return {"success": False, "message": f"Python error managing scene: {str(e)}"}

View File

@ -1,21 +1,24 @@
from mcp.server.fastmcp import FastMCP, Context
from typing import Dict, Any, List
from unity_connection import send_command_with_retry
import base64 import base64
import os import os
from typing import Annotated, Any, Literal
from urllib.parse import urlparse, unquote from urllib.parse import urlparse, unquote
from mcp.server.fastmcp import FastMCP, Context
from unity_connection import send_command_with_retry
try: try:
from telemetry_decorator import telemetry_tool from telemetry_decorator import telemetry_tool
from telemetry import record_milestone, MilestoneType
HAS_TELEMETRY = True HAS_TELEMETRY = True
except ImportError: except ImportError:
HAS_TELEMETRY = False HAS_TELEMETRY = False
def telemetry_tool(tool_name: str): def telemetry_tool(tool_name: str):
def decorator(func): def decorator(func):
return func return func
return decorator return decorator
def register_manage_script_tools(mcp: FastMCP): def register_manage_script_tools(mcp: FastMCP):
"""Register all script management tools with the MCP server.""" """Register all script management tools with the MCP server."""
@ -32,7 +35,7 @@ def register_manage_script_tools(mcp: FastMCP):
""" """
raw_path: str raw_path: str
if uri.startswith("unity://path/"): if uri.startswith("unity://path/"):
raw_path = uri[len("unity://path/") :] raw_path = uri[len("unity://path/"):]
elif uri.startswith("file://"): elif uri.startswith("file://"):
parsed = urlparse(uri) parsed = urlparse(uri)
host = (parsed.netloc or "").strip() host = (parsed.netloc or "").strip()
@ -56,7 +59,8 @@ def register_manage_script_tools(mcp: FastMCP):
# If an 'Assets' segment exists, compute path relative to it (case-insensitive) # If an 'Assets' segment exists, compute path relative to it (case-insensitive)
parts = [p for p in norm.split("/") if p not in ("", ".")] parts = [p for p in norm.split("/") if p not in ("", ".")]
idx = next((i for i, seg in enumerate(parts) if seg.lower() == "assets"), None) idx = next((i for i, seg in enumerate(parts)
if seg.lower() == "assets"), None)
assets_rel = "/".join(parts[idx:]) if idx is not None else None assets_rel = "/".join(parts[idx:]) if idx is not None else None
effective_path = assets_rel if assets_rel else norm effective_path = assets_rel if assets_rel else norm
@ -69,51 +73,47 @@ def register_manage_script_tools(mcp: FastMCP):
directory = os.path.dirname(effective_path) directory = os.path.dirname(effective_path)
return name, directory return name, directory
@mcp.tool(description=( @mcp.tool(name="apply_text_edits", description=(
"Apply small text edits to a C# script identified by URI.\n\n" """Apply small text edits to a C# script identified by URI.
"⚠️ IMPORTANT: This tool replaces EXACT character positions. Always verify content at target lines/columns BEFORE editing!\n" IMPORTANT: This tool replaces EXACT character positions. Always verify content at target lines/columns BEFORE editing!
"Common mistakes:\n" RECOMMENDED WORKFLOW:
"- Assuming what's on a line without checking\n" 1. First call resources/read with start_line/line_count to verify exact content
"- Using wrong line numbers (they're 1-indexed)\n" 2. Count columns carefully (or use find_in_file to locate patterns)
"- Miscounting column positions (also 1-indexed, tabs count as 1)\n\n" 3. Apply your edit with precise coordinates
"RECOMMENDED WORKFLOW:\n" 4. Consider script_apply_edits with anchors for safer pattern-based replacements
"1) First call resources/read with start_line/line_count to verify exact content\n" Notes:
"2) Count columns carefully (or use find_in_file to locate patterns)\n" - For method/class operations, use script_apply_edits (safer, structured edits)
"3) Apply your edit with precise coordinates\n" - For pattern-based replacements, consider anchor operations in script_apply_edits
"4) Consider script_apply_edits with anchors for safer pattern-based replacements\n\n" - Lines, columns are 1-indexed
"Args:\n" - Tabs count as 1 column"""
"- uri: unity://path/Assets/... or file://... or Assets/...\n"
"- edits: list of {startLine,startCol,endLine,endCol,newText} (1-indexed!)\n"
"- precondition_sha256: optional SHA of current file (prevents concurrent edit conflicts)\n\n"
"Notes:\n"
"- Path must resolve under Assets/\n"
"- For method/class operations, use script_apply_edits (safer, structured edits)\n"
"- For pattern-based replacements, consider anchor operations in script_apply_edits\n"
)) ))
@telemetry_tool("apply_text_edits") @telemetry_tool("apply_text_edits")
def apply_text_edits( def apply_text_edits(
ctx: Context, ctx: Context,
uri: str, uri: Annotated[str, "URI of the script to edit under Assets/ directory, unity://path/Assets/... or file://... or Assets/..."],
edits: List[Dict[str, Any]], edits: Annotated[list[dict[str, Any]], "List of edits to apply to the script, i.e. a list of {startLine,startCol,endLine,endCol,newText} (1-indexed!)"],
precondition_sha256: str | None = None, precondition_sha256: Annotated[str,
strict: bool | None = None, "Optional SHA256 of the script to edit, used to prevent concurrent edits"] | None = None,
options: Dict[str, Any] | None = None, strict: Annotated[bool,
) -> Dict[str, Any]: "Optional strict flag, used to enforce strict mode"] | None = None,
"""Apply small text edits to a C# script identified by URI.""" options: Annotated[dict[str, Any],
"Optional options, used to pass additional options to the script editor"] | None = None,
) -> dict[str, Any]:
ctx.info(f"Processing apply_text_edits: {uri}")
name, directory = _split_uri(uri) name, directory = _split_uri(uri)
# Normalize common aliases/misuses for resilience: # Normalize common aliases/misuses for resilience:
# - Accept LSP-style range objects: {range:{start:{line,character}, end:{...}}, newText|text} # - Accept LSP-style range objects: {range:{start:{line,character}, end:{...}}, newText|text}
# - Accept index ranges as a 2-int array: {range:[startIndex,endIndex], text} # - Accept index ranges as a 2-int array: {range:[startIndex,endIndex], text}
# If normalization is required, read current contents to map indices -> 1-based line/col. # If normalization is required, read current contents to map indices -> 1-based line/col.
def _needs_normalization(arr: List[Dict[str, Any]]) -> bool: def _needs_normalization(arr: list[dict[str, Any]]) -> bool:
for e in arr or []: for e in arr or []:
if ("startLine" not in e) or ("startCol" not in e) or ("endLine" not in e) or ("endCol" not in e) or ("newText" not in e and "text" in e): if ("startLine" not in e) or ("startCol" not in e) or ("endLine" not in e) or ("endCol" not in e) or ("newText" not in e and "text" in e):
return True return True
return False return False
normalized_edits: List[Dict[str, Any]] = [] normalized_edits: list[dict[str, Any]] = []
warnings: List[str] = [] warnings: list[str] = []
if _needs_normalization(edits): if _needs_normalization(edits):
# Read file to support index->line/col conversion when needed # Read file to support index->line/col conversion when needed
read_resp = send_command_with_retry("manage_script", { read_resp = send_command_with_retry("manage_script", {
@ -127,7 +127,8 @@ def register_manage_script_tools(mcp: FastMCP):
contents = data.get("contents") contents = data.get("contents")
if not contents and data.get("contentsEncoded"): if not contents and data.get("contentsEncoded"):
try: try:
contents = base64.b64decode(data.get("encodedContents", "").encode("utf-8")).decode("utf-8", "replace") contents = base64.b64decode(data.get("encodedContents", "").encode(
"utf-8")).decode("utf-8", "replace")
except Exception: except Exception:
contents = contents or "" contents = contents or ""
@ -151,7 +152,7 @@ def register_manage_script_tools(mcp: FastMCP):
if "startLine" in e2 and "startCol" in e2 and "endLine" in e2 and "endCol" in e2: if "startLine" in e2 and "startCol" in e2 and "endLine" in e2 and "endCol" in e2:
# Guard: explicit fields must be 1-based. # Guard: explicit fields must be 1-based.
zero_based = False zero_based = False
for k in ("startLine","startCol","endLine","endCol"): for k in ("startLine", "startCol", "endLine", "endCol"):
try: try:
if int(e2.get(k, 1)) < 1: if int(e2.get(k, 1)) < 1:
zero_based = True zero_based = True
@ -161,13 +162,14 @@ def register_manage_script_tools(mcp: FastMCP):
if strict: if strict:
return {"success": False, "code": "zero_based_explicit_fields", "message": "Explicit line/col fields are 1-based; received zero-based.", "data": {"normalizedEdits": normalized_edits}} return {"success": False, "code": "zero_based_explicit_fields", "message": "Explicit line/col fields are 1-based; received zero-based.", "data": {"normalizedEdits": normalized_edits}}
# Normalize by clamping to 1 and warn # Normalize by clamping to 1 and warn
for k in ("startLine","startCol","endLine","endCol"): for k in ("startLine", "startCol", "endLine", "endCol"):
try: try:
if int(e2.get(k, 1)) < 1: if int(e2.get(k, 1)) < 1:
e2[k] = 1 e2[k] = 1
except Exception: except Exception:
pass pass
warnings.append("zero_based_explicit_fields_normalized") warnings.append(
"zero_based_explicit_fields_normalized")
normalized_edits.append(e2) normalized_edits.append(e2)
continue continue
@ -205,17 +207,18 @@ def register_manage_script_tools(mcp: FastMCP):
"success": False, "success": False,
"code": "missing_field", "code": "missing_field",
"message": "apply_text_edits requires startLine/startCol/endLine/endCol/newText or a normalizable 'range'", "message": "apply_text_edits requires startLine/startCol/endLine/endCol/newText or a normalizable 'range'",
"data": {"expected": ["startLine","startCol","endLine","endCol","newText"], "got": e} "data": {"expected": ["startLine", "startCol", "endLine", "endCol", "newText"], "got": e}
} }
else: else:
# Even when edits appear already in explicit form, validate 1-based coordinates. # Even when edits appear already in explicit form, validate 1-based coordinates.
normalized_edits = [] normalized_edits = []
for e in edits or []: for e in edits or []:
e2 = dict(e) e2 = dict(e)
has_all = all(k in e2 for k in ("startLine","startCol","endLine","endCol")) has_all = all(k in e2 for k in (
"startLine", "startCol", "endLine", "endCol"))
if has_all: if has_all:
zero_based = False zero_based = False
for k in ("startLine","startCol","endLine","endCol"): for k in ("startLine", "startCol", "endLine", "endCol"):
try: try:
if int(e2.get(k, 1)) < 1: if int(e2.get(k, 1)) < 1:
zero_based = True zero_based = True
@ -224,21 +227,24 @@ def register_manage_script_tools(mcp: FastMCP):
if zero_based: if zero_based:
if strict: if strict:
return {"success": False, "code": "zero_based_explicit_fields", "message": "Explicit line/col fields are 1-based; received zero-based.", "data": {"normalizedEdits": [e2]}} return {"success": False, "code": "zero_based_explicit_fields", "message": "Explicit line/col fields are 1-based; received zero-based.", "data": {"normalizedEdits": [e2]}}
for k in ("startLine","startCol","endLine","endCol"): for k in ("startLine", "startCol", "endLine", "endCol"):
try: try:
if int(e2.get(k, 1)) < 1: if int(e2.get(k, 1)) < 1:
e2[k] = 1 e2[k] = 1
except Exception: except Exception:
pass pass
if "zero_based_explicit_fields_normalized" not in warnings: if "zero_based_explicit_fields_normalized" not in warnings:
warnings.append("zero_based_explicit_fields_normalized") warnings.append(
"zero_based_explicit_fields_normalized")
normalized_edits.append(e2) normalized_edits.append(e2)
# Preflight: detect overlapping ranges among normalized line/col spans # Preflight: detect overlapping ranges among normalized line/col spans
def _pos_tuple(e: Dict[str, Any], key_start: bool) -> tuple[int, int]: def _pos_tuple(e: dict[str, Any], key_start: bool) -> tuple[int, int]:
return ( return (
int(e.get("startLine", 1)) if key_start else int(e.get("endLine", 1)), int(e.get("startLine", 1)) if key_start else int(
int(e.get("startCol", 1)) if key_start else int(e.get("endCol", 1)), e.get("endLine", 1)),
int(e.get("startCol", 1)) if key_start else int(
e.get("endCol", 1)),
) )
def _le(a: tuple[int, int], b: tuple[int, int]) -> bool: def _le(a: tuple[int, int], b: tuple[int, int]) -> bool:
@ -276,7 +282,7 @@ def register_manage_script_tools(mcp: FastMCP):
# preserves existing call-count expectations in clients/tests. # preserves existing call-count expectations in clients/tests.
# Default options: for multi-span batches, prefer atomic to avoid mid-apply imbalance # Default options: for multi-span batches, prefer atomic to avoid mid-apply imbalance
opts: Dict[str, Any] = dict(options or {}) opts: dict[str, Any] = dict(options or {})
try: try:
if len(normalized_edits) > 1 and "applyMode" not in opts: if len(normalized_edits) > 1 and "applyMode" not in opts:
opts["applyMode"] = "atomic" opts["applyMode"] = "atomic"
@ -320,10 +326,16 @@ def register_manage_script_tools(mcp: FastMCP):
if resp.get("success") and (options or {}).get("force_sentinel_reload"): if resp.get("success") and (options or {}).get("force_sentinel_reload"):
# Optional: flip sentinel via menu if explicitly requested # Optional: flip sentinel via menu if explicitly requested
try: try:
import threading, time, json, glob, os import threading
import time
import json
import glob
import os
def _latest_status() -> dict | None: def _latest_status() -> dict | None:
try: try:
files = sorted(glob.glob(os.path.expanduser("~/.unity-mcp/unity-mcp-status-*.json")), key=os.path.getmtime, reverse=True) files = sorted(glob.glob(os.path.expanduser(
"~/.unity-mcp/unity-mcp-status-*.json")), key=os.path.getmtime, reverse=True)
if not files: if not files:
return None return None
with open(files[0], "r") as f: with open(files[0], "r") as f:
@ -352,24 +364,21 @@ def register_manage_script_tools(mcp: FastMCP):
return resp return resp
return {"success": False, "message": str(resp)} return {"success": False, "message": str(resp)}
@mcp.tool(description=( @mcp.tool(name="create_script", description=("Create a new C# script at the given project path."))
"Create a new C# script at the given project path.\n\n"
"Args: path (e.g., 'Assets/Scripts/My.cs'), contents (string), script_type, namespace.\n"
"Rules: path must be under Assets/. Contents will be Base64-encoded over transport.\n"
))
@telemetry_tool("create_script") @telemetry_tool("create_script")
def create_script( def create_script(
ctx: Context, ctx: Context,
path: str, path: Annotated[str, "Path under Assets/ to create the script at, e.g., 'Assets/Scripts/My.cs'"],
contents: str = "", contents: Annotated[str, "Contents of the script to create. Note, this is Base64 encoded over transport."],
script_type: str | None = None, script_type: Annotated[str, "Script type (e.g., 'C#')"] | None = None,
namespace: str | None = None, namespace: Annotated[str, "Namespace for the script"] | None = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Create a new C# script at the given path.""" ctx.info(f"Processing create_script: {path}")
name = os.path.splitext(os.path.basename(path))[0] name = os.path.splitext(os.path.basename(path))[0]
directory = os.path.dirname(path) directory = os.path.dirname(path)
# Local validation to avoid round-trips on obviously bad input # Local validation to avoid round-trips on obviously bad input
norm_path = os.path.normpath((path or "").replace("\\", "/")).replace("\\", "/") norm_path = os.path.normpath(
(path or "").replace("\\", "/")).replace("\\", "/")
if not directory or directory.split("/")[0].lower() != "assets": if not directory or directory.split("/")[0].lower() != "assets":
return {"success": False, "code": "path_outside_assets", "message": f"path must be under 'Assets/'; got '{path}'."} return {"success": False, "code": "path_outside_assets", "message": f"path must be under 'Assets/'; got '{path}'."}
if ".." in norm_path.split("/") or norm_path.startswith("/"): if ".." in norm_path.split("/") or norm_path.startswith("/"):
@ -378,7 +387,7 @@ def register_manage_script_tools(mcp: FastMCP):
return {"success": False, "code": "bad_path", "message": "path must include a script file name."} return {"success": False, "code": "bad_path", "message": "path must include a script file name."}
if not norm_path.lower().endswith(".cs"): if not norm_path.lower().endswith(".cs"):
return {"success": False, "code": "bad_extension", "message": "script file must end with .cs."} return {"success": False, "code": "bad_extension", "message": "script file must end with .cs."}
params: Dict[str, Any] = { params: dict[str, Any] = {
"action": "create", "action": "create",
"name": name, "name": name,
"path": directory, "path": directory,
@ -386,20 +395,21 @@ def register_manage_script_tools(mcp: FastMCP):
"scriptType": script_type, "scriptType": script_type,
} }
if contents: if contents:
params["encodedContents"] = base64.b64encode(contents.encode("utf-8")).decode("utf-8") params["encodedContents"] = base64.b64encode(
contents.encode("utf-8")).decode("utf-8")
params["contentsEncoded"] = True params["contentsEncoded"] = True
params = {k: v for k, v in params.items() if v is not None} params = {k: v for k, v in params.items() if v is not None}
resp = send_command_with_retry("manage_script", params) resp = send_command_with_retry("manage_script", params)
return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)} return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)}
@mcp.tool(description=( @mcp.tool(name="delete_script", description=("Delete a C# script by URI or Assets-relative path."))
"Delete a C# script by URI or Assets-relative path.\n\n"
"Args: uri (unity://path/... or file://... or Assets/...).\n"
"Rules: Target must resolve under Assets/.\n"
))
@telemetry_tool("delete_script") @telemetry_tool("delete_script")
def delete_script(ctx: Context, uri: str) -> Dict[str, Any]: def delete_script(
ctx: Context,
uri: Annotated[str, "URI of the script to delete under Assets/ directory, unity://path/Assets/... or file://... or Assets/..."]
) -> dict[str, Any]:
"""Delete a C# script by URI.""" """Delete a C# script by URI."""
ctx.info(f"Processing delete_script: {uri}")
name, directory = _split_uri(uri) name, directory = _split_uri(uri)
if not directory or directory.split("/")[0].lower() != "assets": if not directory or directory.split("/")[0].lower() != "assets":
return {"success": False, "code": "path_outside_assets", "message": "URI must resolve under 'Assets/'."} return {"success": False, "code": "path_outside_assets", "message": "URI must resolve under 'Assets/'."}
@ -407,18 +417,17 @@ def register_manage_script_tools(mcp: FastMCP):
resp = send_command_with_retry("manage_script", params) resp = send_command_with_retry("manage_script", params)
return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)} return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)}
@mcp.tool(description=( @mcp.tool(name="validate_script", description=("Validate a C# script and return diagnostics."))
"Validate a C# script and return diagnostics.\n\n"
"Args: uri, level=('basic'|'standard'), include_diagnostics (bool, optional).\n"
"- basic: quick syntax checks.\n"
"- standard: deeper checks (performance hints, common pitfalls).\n"
"- include_diagnostics: when true, returns full diagnostics and summary; default returns counts only.\n"
))
@telemetry_tool("validate_script") @telemetry_tool("validate_script")
def validate_script( def validate_script(
ctx: Context, uri: str, level: str = "basic", include_diagnostics: bool = False ctx: Context,
) -> Dict[str, Any]: uri: Annotated[str, "URI of the script to validate under Assets/ directory, unity://path/Assets/... or file://... or Assets/..."],
"""Validate a C# script and return diagnostics.""" level: Annotated[Literal['basic', 'standard'],
"Validation level"] = "basic",
include_diagnostics: Annotated[bool,
"Include full diagnostics and summary"] = False
) -> dict[str, Any]:
ctx.info(f"Processing validate_script: {uri}")
name, directory = _split_uri(uri) name, directory = _split_uri(uri)
if not directory or directory.split("/")[0].lower() != "assets": if not directory or directory.split("/")[0].lower() != "assets":
return {"success": False, "code": "path_outside_assets", "message": "URI must resolve under 'Assets/'."} return {"success": False, "code": "path_outside_assets", "message": "URI must resolve under 'Assets/'."}
@ -433,103 +442,30 @@ def register_manage_script_tools(mcp: FastMCP):
resp = send_command_with_retry("manage_script", params) resp = send_command_with_retry("manage_script", params)
if isinstance(resp, dict) and resp.get("success"): if isinstance(resp, dict) and resp.get("success"):
diags = resp.get("data", {}).get("diagnostics", []) or [] diags = resp.get("data", {}).get("diagnostics", []) or []
warnings = sum(1 for d in diags if str(d.get("severity", "")).lower() == "warning") warnings = sum(1 for d in diags if str(
errors = sum(1 for d in diags if str(d.get("severity", "")).lower() in ("error", "fatal")) d.get("severity", "")).lower() == "warning")
errors = sum(1 for d in diags if str(
d.get("severity", "")).lower() in ("error", "fatal"))
if include_diagnostics: if include_diagnostics:
return {"success": True, "data": {"diagnostics": diags, "summary": {"warnings": warnings, "errors": errors}}} return {"success": True, "data": {"diagnostics": diags, "summary": {"warnings": warnings, "errors": errors}}}
return {"success": True, "data": {"warnings": warnings, "errors": errors}} return {"success": True, "data": {"warnings": warnings, "errors": errors}}
return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)} return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)}
@mcp.tool(description=( @mcp.tool(name="manage_script", description=("Compatibility router for legacy script operations. Prefer apply_text_edits (ranges) or script_apply_edits (structured) for edits."))
"Compatibility router for legacy script operations.\n\n"
"Actions: create|read|delete (update is routed to apply_text_edits with precondition).\n"
"Args: name (no .cs), path (Assets/...), contents (for create), script_type, namespace.\n"
"Notes: prefer apply_text_edits (ranges) or script_apply_edits (structured) for edits.\n"
))
@telemetry_tool("manage_script") @telemetry_tool("manage_script")
def manage_script( def manage_script(
ctx: Context, ctx: Context,
action: str, action: Annotated[Literal['create', 'read', 'delete'], "Perform CRUD operations on C# scripts."],
name: str, name: Annotated[str, "Script name (no .cs extension)", "Name of the script to create"],
path: str, path: Annotated[str, "Asset path (default: 'Assets/')", "Path under Assets/ to create the script at, e.g., 'Assets/Scripts/My.cs'"],
contents: str = "", contents: Annotated[str, "Contents of the script to create",
script_type: str | None = None, "C# code for 'create'/'update'"] | None = None,
namespace: str | None = None, script_type: Annotated[str, "Script type (e.g., 'C#')",
) -> Dict[str, Any]: "Type hint (e.g., 'MonoBehaviour')"] | None = None,
"""Compatibility router for legacy script operations. namespace: Annotated[str, "Namespace for the script"] | None = None,
) -> dict[str, Any]:
IMPORTANT: ctx.info(f"Processing manage_script: {action}")
- Direct file reads should use resources/read.
- Edits should use apply_text_edits.
Args:
action: Operation ('create', 'read', 'delete').
name: Script name (no .cs extension).
path: Asset path (default: "Assets/").
contents: C# code for 'create'/'update'.
script_type: Type hint (e.g., 'MonoBehaviour').
namespace: Script namespace.
Returns:
Dictionary with results ('success', 'message', 'data').
"""
try: try:
# Graceful migration for legacy 'update': route to apply_text_edits (whole-file replace)
if action == 'update':
try:
# 1) Read current contents to compute end range and precondition
read_resp = send_command_with_retry("manage_script", {
"action": "read",
"name": name,
"path": path,
})
if not (isinstance(read_resp, dict) and read_resp.get("success")):
return {"success": False, "code": "deprecated_update", "message": "Use apply_text_edits; automatic migration failed to read current file."}
data = read_resp.get("data", {})
current = data.get("contents")
if not current and data.get("contentsEncoded"):
current = base64.b64decode(data.get("encodedContents", "").encode("utf-8")).decode("utf-8", "replace")
if current is None:
return {"success": False, "code": "deprecated_update", "message": "Use apply_text_edits; current file read returned no contents."}
# 2) Compute whole-file range (1-based, end exclusive) and SHA
import hashlib as _hashlib
old_lines = current.splitlines(keepends=True)
end_line = len(old_lines) + 1
sha = _hashlib.sha256(current.encode("utf-8")).hexdigest()
# 3) Apply single whole-file text edit with provided 'contents'
edits = [{
"startLine": 1,
"startCol": 1,
"endLine": end_line,
"endCol": 1,
"newText": contents or "",
}]
route_params = {
"action": "apply_text_edits",
"name": name,
"path": path,
"edits": edits,
"precondition_sha256": sha,
"options": {"refresh": "debounced", "validate": "standard"},
}
# Preflight size vs. default cap (256 KiB) to avoid opaque server errors
try:
import json as _json
payload_bytes = len(_json.dumps({"edits": edits}, ensure_ascii=False).encode("utf-8"))
if payload_bytes > 256 * 1024:
return {"success": False, "code": "payload_too_large", "message": f"Edit payload {payload_bytes} bytes exceeds 256 KiB cap; try structured ops or chunking."}
except Exception:
pass
routed = send_command_with_retry("manage_script", route_params)
if isinstance(routed, dict):
routed.setdefault("message", "Routed legacy update to apply_text_edits")
return routed
return {"success": False, "message": str(routed)}
except Exception as e:
return {"success": False, "code": "deprecated_update", "message": f"Use apply_text_edits; migration error: {e}"}
# Prepare parameters for Unity # Prepare parameters for Unity
params = { params = {
"action": action, "action": action,
@ -542,7 +478,8 @@ def register_manage_script_tools(mcp: FastMCP):
# Base64 encode the contents if they exist to avoid JSON escaping issues # Base64 encode the contents if they exist to avoid JSON escaping issues
if contents: if contents:
if action == 'create': if action == 'create':
params["encodedContents"] = base64.b64encode(contents.encode('utf-8')).decode('utf-8') params["encodedContents"] = base64.b64encode(
contents.encode('utf-8')).decode('utf-8')
params["contentsEncoded"] = True params["contentsEncoded"] = True
else: else:
params["contents"] = contents params["contents"] = contents
@ -554,7 +491,8 @@ def register_manage_script_tools(mcp: FastMCP):
if isinstance(response, dict): if isinstance(response, dict):
if response.get("success"): if response.get("success"):
if response.get("data", {}).get("contentsEncoded"): if response.get("data", {}).get("contentsEncoded"):
decoded_contents = base64.b64decode(response["data"]["encodedContents"]).decode('utf-8') decoded_contents = base64.b64decode(
response["data"]["encodedContents"]).decode('utf-8')
response["data"]["contents"] = decoded_contents response["data"]["contents"] = decoded_contents
del response["data"]["encodedContents"] del response["data"]["encodedContents"]
del response["data"]["contentsEncoded"] del response["data"]["contentsEncoded"]
@ -574,19 +512,24 @@ def register_manage_script_tools(mcp: FastMCP):
"message": f"Python error managing script: {str(e)}", "message": f"Python error managing script: {str(e)}",
} }
@mcp.tool(description=( @mcp.tool(name="manage_script_capabilities", description=(
"Get manage_script capabilities (supported ops, limits, and guards).\n\n" """Get manage_script capabilities (supported ops, limits, and guards).
"Returns:\n- ops: list of supported structured ops\n- text_ops: list of supported text ops\n- max_edit_payload_bytes: server edit payload cap\n- guards: header/using guard enabled flag\n" Returns:
- ops: list of supported structured ops
- text_ops: list of supported text ops
- max_edit_payload_bytes: server edit payload cap
- guards: header/using guard enabled flag"""
)) ))
@telemetry_tool("manage_script_capabilities") @telemetry_tool("manage_script_capabilities")
def manage_script_capabilities(ctx: Context) -> Dict[str, Any]: def manage_script_capabilities(ctx: Context) -> dict[str, Any]:
ctx.info("Processing manage_script_capabilities")
try: try:
# Keep in sync with server/Editor ManageScript implementation # Keep in sync with server/Editor ManageScript implementation
ops = [ ops = [
"replace_class","delete_class","replace_method","delete_method", "replace_class", "delete_class", "replace_method", "delete_method",
"insert_method","anchor_insert","anchor_delete","anchor_replace" "insert_method", "anchor_insert", "anchor_delete", "anchor_replace"
] ]
text_ops = ["replace_range","regex_replace","prepend","append"] text_ops = ["replace_range", "regex_replace", "prepend", "append"]
# Match ManageScript.MaxEditPayloadBytes if exposed; hardcode a sensible default fallback # Match ManageScript.MaxEditPayloadBytes if exposed; hardcode a sensible default fallback
max_edit_payload_bytes = 256 * 1024 max_edit_payload_bytes = 256 * 1024
guards = {"using_guard": True} guards = {"using_guard": True}
@ -601,21 +544,21 @@ def register_manage_script_tools(mcp: FastMCP):
except Exception as e: except Exception as e:
return {"success": False, "error": f"capabilities error: {e}"} return {"success": False, "error": f"capabilities error: {e}"}
@mcp.tool(description=( @mcp.tool(name="get_sha", description="Get SHA256 and basic metadata for a Unity C# script without returning file contents")
"Get SHA256 and basic metadata for a Unity C# script without returning file contents.\n\n"
"Args: uri (unity://path/Assets/... or file://... or Assets/...).\n"
"Returns: {sha256, lengthBytes}."
))
@telemetry_tool("get_sha") @telemetry_tool("get_sha")
def get_sha(ctx: Context, uri: str) -> Dict[str, Any]: def get_sha(
"""Return SHA256 and basic metadata for a script.""" ctx: Context,
uri: Annotated[str, "URI of the script to edit under Assets/ directory, unity://path/Assets/... or file://... or Assets/..."]
) -> dict[str, Any]:
ctx.info(f"Processing get_sha: {uri}")
try: try:
name, directory = _split_uri(uri) name, directory = _split_uri(uri)
params = {"action": "get_sha", "name": name, "path": directory} params = {"action": "get_sha", "name": name, "path": directory}
resp = send_command_with_retry("manage_script", params) resp = send_command_with_retry("manage_script", params)
if isinstance(resp, dict) and resp.get("success"): if isinstance(resp, dict) and resp.get("success"):
data = resp.get("data", {}) data = resp.get("data", {})
minimal = {"sha256": data.get("sha256"), "lengthBytes": data.get("lengthBytes")} minimal = {"sha256": data.get(
"sha256"), "lengthBytes": data.get("lengthBytes")}
return {"success": True, "data": minimal} return {"success": True, "data": minimal}
return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)} return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)}
except Exception as e: except Exception as e:

View File

@ -1,14 +1,15 @@
from mcp.server.fastmcp import FastMCP, Context
from typing import Dict, Any, List, Tuple, Optional
import base64 import base64
import hashlib
import re import re
import os from typing import Annotated, Any
from unity_connection import send_command_with_retry
from mcp.server.fastmcp import FastMCP, Context
from telemetry_decorator import telemetry_tool from telemetry_decorator import telemetry_tool
from unity_connection import send_command_with_retry
def _apply_edits_locally(original_text: str, edits: List[Dict[str, Any]]) -> str:
def _apply_edits_locally(original_text: str, edits: list[dict[str, Any]]) -> str:
text = original_text text = original_text
for edit in edits or []: for edit in edits or []:
op = ( op = (
@ -29,7 +30,8 @@ def _apply_edits_locally(original_text: str, edits: List[Dict[str, Any]]) -> str
if op == "prepend": if op == "prepend":
prepend_text = edit.get("text", "") prepend_text = edit.get("text", "")
text = (prepend_text if prepend_text.endswith("\n") else prepend_text + "\n") + text text = (prepend_text if prepend_text.endswith(
"\n") else prepend_text + "\n") + text
elif op == "append": elif op == "append":
append_text = edit.get("text", "") append_text = edit.get("text", "")
if not text.endswith("\n"): if not text.endswith("\n"):
@ -41,10 +43,12 @@ def _apply_edits_locally(original_text: str, edits: List[Dict[str, Any]]) -> str
anchor = edit.get("anchor", "") anchor = edit.get("anchor", "")
position = (edit.get("position") or "before").lower() position = (edit.get("position") or "before").lower()
insert_text = edit.get("text", "") insert_text = edit.get("text", "")
flags = re.MULTILINE | (re.IGNORECASE if edit.get("ignore_case") else 0) flags = re.MULTILINE | (
re.IGNORECASE if edit.get("ignore_case") else 0)
# Find the best match using improved heuristics # Find the best match using improved heuristics
match = _find_best_anchor_match(anchor, text, flags, bool(edit.get("prefer_last", True))) match = _find_best_anchor_match(
anchor, text, flags, bool(edit.get("prefer_last", True)))
if not match: if not match:
if edit.get("allow_noop", True): if edit.get("allow_noop", True):
continue continue
@ -53,15 +57,16 @@ def _apply_edits_locally(original_text: str, edits: List[Dict[str, Any]]) -> str
text = text[:idx] + insert_text + text[idx:] text = text[:idx] + insert_text + text[idx:]
elif op == "replace_range": elif op == "replace_range":
start_line = int(edit.get("startLine", 1)) start_line = int(edit.get("startLine", 1))
start_col = int(edit.get("startCol", 1)) start_col = int(edit.get("startCol", 1))
end_line = int(edit.get("endLine", start_line)) end_line = int(edit.get("endLine", start_line))
end_col = int(edit.get("endCol", 1)) end_col = int(edit.get("endCol", 1))
replacement = edit.get("text", "") replacement = edit.get("text", "")
lines = text.splitlines(keepends=True) lines = text.splitlines(keepends=True)
max_line = len(lines) + 1 # 1-based, exclusive end max_line = len(lines) + 1 # 1-based, exclusive end
if (start_line < 1 or end_line < start_line or end_line > max_line if (start_line < 1 or end_line < start_line or end_line > max_line
or start_col < 1 or end_col < 1): or start_col < 1 or end_col < 1):
raise RuntimeError("replace_range out of bounds") raise RuntimeError("replace_range out of bounds")
def index_of(line: int, col: int) -> int: def index_of(line: int, col: int) -> int:
if line <= len(lines): if line <= len(lines):
return sum(len(l) for l in lines[: line - 1]) + (col - 1) return sum(len(l) for l in lines[: line - 1]) + (col - 1)
@ -81,48 +86,49 @@ def _apply_edits_locally(original_text: str, edits: List[Dict[str, Any]]) -> str
text = re.sub(pattern, repl_py, text, count=count, flags=flags) text = re.sub(pattern, repl_py, text, count=count, flags=flags)
else: else:
allowed = "anchor_insert, prepend, append, replace_range, regex_replace" allowed = "anchor_insert, prepend, append, replace_range, regex_replace"
raise RuntimeError(f"unknown edit op: {op}; allowed: {allowed}. Use 'op' (aliases accepted: type/mode/operation).") raise RuntimeError(
f"unknown edit op: {op}; allowed: {allowed}. Use 'op' (aliases accepted: type/mode/operation).")
return text return text
def _find_best_anchor_match(pattern: str, text: str, flags: int, prefer_last: bool = True): def _find_best_anchor_match(pattern: str, text: str, flags: int, prefer_last: bool = True):
""" """
Find the best anchor match using improved heuristics. Find the best anchor match using improved heuristics.
For patterns like \\s*}\\s*$ that are meant to find class-ending braces, For patterns like \\s*}\\s*$ that are meant to find class-ending braces,
this function uses heuristics to choose the most semantically appropriate match: this function uses heuristics to choose the most semantically appropriate match:
1. If prefer_last=True, prefer the last match (common for class-end insertions) 1. If prefer_last=True, prefer the last match (common for class-end insertions)
2. Use indentation levels to distinguish class vs method braces 2. Use indentation levels to distinguish class vs method braces
3. Consider context to avoid matches inside strings/comments 3. Consider context to avoid matches inside strings/comments
Args: Args:
pattern: Regex pattern to search for pattern: Regex pattern to search for
text: Text to search in text: Text to search in
flags: Regex flags flags: Regex flags
prefer_last: If True, prefer the last match over the first prefer_last: If True, prefer the last match over the first
Returns: Returns:
Match object of the best match, or None if no match found Match object of the best match, or None if no match found
""" """
import re
# Find all matches # Find all matches
matches = list(re.finditer(pattern, text, flags)) matches = list(re.finditer(pattern, text, flags))
if not matches: if not matches:
return None return None
# If only one match, return it # If only one match, return it
if len(matches) == 1: if len(matches) == 1:
return matches[0] return matches[0]
# For patterns that look like they're trying to match closing braces at end of lines # For patterns that look like they're trying to match closing braces at end of lines
is_closing_brace_pattern = '}' in pattern and ('$' in pattern or pattern.endswith(r'\s*')) is_closing_brace_pattern = '}' in pattern and (
'$' in pattern or pattern.endswith(r'\s*'))
if is_closing_brace_pattern and prefer_last: if is_closing_brace_pattern and prefer_last:
# Use heuristics to find the best closing brace match # Use heuristics to find the best closing brace match
return _find_best_closing_brace_match(matches, text) return _find_best_closing_brace_match(matches, text)
# Default behavior: use last match if prefer_last, otherwise first match # Default behavior: use last match if prefer_last, otherwise first match
return matches[-1] if prefer_last else matches[0] return matches[-1] if prefer_last else matches[0]
@ -130,68 +136,70 @@ def _find_best_anchor_match(pattern: str, text: str, flags: int, prefer_last: bo
def _find_best_closing_brace_match(matches, text: str): def _find_best_closing_brace_match(matches, text: str):
""" """
Find the best closing brace match using C# structure heuristics. Find the best closing brace match using C# structure heuristics.
Enhanced heuristics for scope-aware matching: Enhanced heuristics for scope-aware matching:
1. Prefer matches with lower indentation (likely class-level) 1. Prefer matches with lower indentation (likely class-level)
2. Prefer matches closer to end of file 2. Prefer matches closer to end of file
3. Avoid matches that seem to be inside method bodies 3. Avoid matches that seem to be inside method bodies
4. For #endregion patterns, ensure class-level context 4. For #endregion patterns, ensure class-level context
5. Validate insertion point is at appropriate scope 5. Validate insertion point is at appropriate scope
Args: Args:
matches: List of regex match objects matches: List of regex match objects
text: The full text being searched text: The full text being searched
Returns: Returns:
The best match object The best match object
""" """
if not matches: if not matches:
return None return None
scored_matches = [] scored_matches = []
lines = text.splitlines() lines = text.splitlines()
for match in matches: for match in matches:
score = 0 score = 0
start_pos = match.start() start_pos = match.start()
# Find which line this match is on # Find which line this match is on
lines_before = text[:start_pos].count('\n') lines_before = text[:start_pos].count('\n')
line_num = lines_before line_num = lines_before
if line_num < len(lines): if line_num < len(lines):
line_content = lines[line_num] line_content = lines[line_num]
# Calculate indentation level (lower is better for class braces) # Calculate indentation level (lower is better for class braces)
indentation = len(line_content) - len(line_content.lstrip()) indentation = len(line_content) - len(line_content.lstrip())
# Prefer lower indentation (class braces are typically less indented than method braces) # Prefer lower indentation (class braces are typically less indented than method braces)
score += max(0, 20 - indentation) # Max 20 points for indentation=0 # Max 20 points for indentation=0
score += max(0, 20 - indentation)
# Prefer matches closer to end of file (class closing braces are typically at the end) # Prefer matches closer to end of file (class closing braces are typically at the end)
distance_from_end = len(lines) - line_num distance_from_end = len(lines) - line_num
score += max(0, 10 - distance_from_end) # More points for being closer to end # More points for being closer to end
score += max(0, 10 - distance_from_end)
# Look at surrounding context to avoid method braces # Look at surrounding context to avoid method braces
context_start = max(0, line_num - 3) context_start = max(0, line_num - 3)
context_end = min(len(lines), line_num + 2) context_end = min(len(lines), line_num + 2)
context_lines = lines[context_start:context_end] context_lines = lines[context_start:context_end]
# Penalize if this looks like it's inside a method (has method-like patterns above) # Penalize if this looks like it's inside a method (has method-like patterns above)
for context_line in context_lines: for context_line in context_lines:
if re.search(r'\b(void|public|private|protected)\s+\w+\s*\(', context_line): if re.search(r'\b(void|public|private|protected)\s+\w+\s*\(', context_line):
score -= 5 # Penalty for being near method signatures score -= 5 # Penalty for being near method signatures
# Bonus if this looks like a class-ending brace (very minimal indentation and near EOF) # Bonus if this looks like a class-ending brace (very minimal indentation and near EOF)
if indentation <= 4 and distance_from_end <= 3: if indentation <= 4 and distance_from_end <= 3:
score += 15 # Bonus for likely class-ending brace score += 15 # Bonus for likely class-ending brace
scored_matches.append((score, match)) scored_matches.append((score, match))
# Return the match with the highest score # Return the match with the highest score
scored_matches.sort(key=lambda x: x[0], reverse=True) scored_matches.sort(key=lambda x: x[0], reverse=True)
best_match = scored_matches[0][1] best_match = scored_matches[0][1]
return best_match return best_match
@ -209,8 +217,7 @@ def _extract_code_after(keyword: str, request: str) -> str:
# Removed _is_structurally_balanced - validation now handled by C# side using Unity's compiler services # Removed _is_structurally_balanced - validation now handled by C# side using Unity's compiler services
def _normalize_script_locator(name: str, path: str) -> tuple[str, str]:
def _normalize_script_locator(name: str, path: str) -> Tuple[str, str]:
"""Best-effort normalization of script "name" and "path". """Best-effort normalization of script "name" and "path".
Accepts any of: Accepts any of:
@ -258,7 +265,8 @@ def _normalize_script_locator(name: str, path: str) -> Tuple[str, str]:
parts = candidate.split("/") parts = candidate.split("/")
file_name = parts[-1] file_name = parts[-1]
dir_path = "/".join(parts[:-1]) if len(parts) > 1 else "Assets" dir_path = "/".join(parts[:-1]) if len(parts) > 1 else "Assets"
base = file_name[:-3] if file_name.lower().endswith(".cs") else file_name base = file_name[:-
3] if file_name.lower().endswith(".cs") else file_name
return base, dir_path return base, dir_path
# Fall back: remove extension from name if present and return given path # Fall back: remove extension from name if present and return given path
@ -266,7 +274,7 @@ def _normalize_script_locator(name: str, path: str) -> Tuple[str, str]:
return base_name, (p or "Assets") return base_name, (p or "Assets")
def _with_norm(resp: Dict[str, Any] | Any, edits: List[Dict[str, Any]], routing: str | None = None) -> Dict[str, Any] | Any: def _with_norm(resp: dict[str, Any] | Any, edits: list[dict[str, Any]], routing: str | None = None) -> dict[str, Any] | Any:
if not isinstance(resp, dict): if not isinstance(resp, dict):
return resp return resp
data = resp.setdefault("data", {}) data = resp.setdefault("data", {})
@ -276,10 +284,11 @@ def _with_norm(resp: Dict[str, Any] | Any, edits: List[Dict[str, Any]], routing:
return resp return resp
def _err(code: str, message: str, *, expected: Dict[str, Any] | None = None, rewrite: Dict[str, Any] | None = None, def _err(code: str, message: str, *, expected: dict[str, Any] | None = None, rewrite: dict[str, Any] | None = None,
normalized: List[Dict[str, Any]] | None = None, routing: str | None = None, extra: Dict[str, Any] | None = None) -> Dict[str, Any]: normalized: list[dict[str, Any]] | None = None, routing: str | None = None, extra: dict[str, Any] | None = None) -> dict[str, Any]:
payload: Dict[str, Any] = {"success": False, "code": code, "message": message} payload: dict[str, Any] = {"success": False,
data: Dict[str, Any] = {} "code": code, "message": message}
data: dict[str, Any] = {}
if expected: if expected:
data["expected"] = expected data["expected"] = expected
if rewrite: if rewrite:
@ -298,77 +307,78 @@ def _err(code: str, message: str, *, expected: Dict[str, Any] | None = None, rew
def register_manage_script_edits_tools(mcp: FastMCP): def register_manage_script_edits_tools(mcp: FastMCP):
@mcp.tool(description=( @mcp.tool(name="script_apply_edits", description=(
"Structured C# edits (methods/classes) with safer boundaries — prefer this over raw text.\n\n" """Structured C# edits (methods/classes) with safer boundaries - prefer this over raw text.
"Best practices:\n" Best practices:
"- Prefer anchor_* ops for pattern-based insert/replace near stable markers\n" - Prefer anchor_* ops for pattern-based insert/replace near stable markers
"- Use replace_method/delete_method for whole-method changes (keeps signatures balanced)\n" - Use replace_method/delete_method for whole-method changes (keeps signatures balanced)
"- Avoid whole-file regex deletes; validators will guard unbalanced braces\n" - Avoid whole-file regex deletes; validators will guard unbalanced braces
"- For tail insertions, prefer anchor/regex_replace on final brace (class closing)\n" - For tail insertions, prefer anchor/regex_replace on final brace (class closing)
"- Pass options.validate='standard' for structural checks; 'relaxed' for interior-only edits\n\n" - Pass options.validate='standard' for structural checks; 'relaxed' for interior-only edits
"Canonical fields (use these exact keys):\n" Canonical fields (use these exact keys):
"- op: replace_method | insert_method | delete_method | anchor_insert | anchor_delete | anchor_replace\n" - op: replace_method | insert_method | delete_method | anchor_insert | anchor_delete | anchor_replace
"- className: string (defaults to 'name' if omitted on method/class ops)\n" - className: string (defaults to 'name' if omitted on method/class ops)
"- methodName: string (required for replace_method, delete_method)\n" - methodName: string (required for replace_method, delete_method)
"- replacement: string (required for replace_method, insert_method)\n" - replacement: string (required for replace_method, insert_method)
"- position: start | end | after | before (insert_method only)\n" - position: start | end | after | before (insert_method only)
"- afterMethodName / beforeMethodName: string (required when position='after'/'before')\n" - afterMethodName / beforeMethodName: string (required when position='after'/'before')
"- anchor: regex string (for anchor_* ops)\n" - anchor: regex string (for anchor_* ops)
"- text: string (for anchor_insert/anchor_replace)\n\n" - text: string (for anchor_insert/anchor_replace)
"Do NOT use: new_method, anchor_method, content, newText (aliases accepted but normalized).\n\n" Examples:
"Examples:\n" 1) Replace a method:
"1) Replace a method:\n" {
"{\n" "name": "SmartReach",
" \"name\": \"SmartReach\",\n" "path": "Assets/Scripts/Interaction",
" \"path\": \"Assets/Scripts/Interaction\",\n" "edits": [
" \"edits\": [\n" {
" {\n" "op": "replace_method",
" \"op\": \"replace_method\",\n" "className": "SmartReach",
" \"className\": \"SmartReach\",\n" "methodName": "HasTarget",
" \"methodName\": \"HasTarget\",\n" "replacement": "public bool HasTarget(){ return currentTarget!=null; }"
" \"replacement\": \"public bool HasTarget(){ return currentTarget!=null; }\"\n" }
" }\n" ],
" ],\n" "options": {"validate": "standard", "refresh": "immediate"}
" \"options\": {\"validate\": \"standard\", \"refresh\": \"immediate\"}\n" }
"}\n\n" "2) Insert a method after another:
"2) Insert a method after another:\n" {
"{\n" "name": "SmartReach",
" \"name\": \"SmartReach\",\n" "path": "Assets/Scripts/Interaction",
" \"path\": \"Assets/Scripts/Interaction\",\n" "edits": [
" \"edits\": [\n" {
" {\n" "op": "insert_method",
" \"op\": \"insert_method\",\n" "className": "SmartReach",
" \"className\": \"SmartReach\",\n" "replacement": "public void PrintSeries(){ Debug.Log(seriesName); }",
" \"replacement\": \"public void PrintSeries(){ Debug.Log(seriesName); }\",\n" "position": "after",
" \"position\": \"after\",\n" "afterMethodName": "GetCurrentTarget"
" \"afterMethodName\": \"GetCurrentTarget\"\n" }
" }\n" ],
" ]\n" }
"}\n\n" ]"""
"Note: 'options' must be an object/dict, not a string. Use proper JSON syntax.\n"
)) ))
@telemetry_tool("script_apply_edits") @telemetry_tool("script_apply_edits")
def script_apply_edits( def script_apply_edits(
ctx: Context, ctx: Context,
name: str, name: Annotated[str, "Name of the script to edit"],
path: str, path: Annotated[str, "Path to the script to edit under Assets/ directory"],
edits: List[Dict[str, Any]], edits: Annotated[list[dict[str, Any]], "List of edits to apply to the script"],
options: Optional[Dict[str, Any]] = None, options: Annotated[dict[str, Any],
script_type: str = "MonoBehaviour", "Options for the script edit"] | None = None,
namespace: str = "", script_type: Annotated[str,
) -> Dict[str, Any]: "Type of the script to edit"] = "MonoBehaviour",
namespace: Annotated[str,
"Namespace of the script to edit"] | None = None,
) -> dict[str, Any]:
ctx.info(f"Processing script_apply_edits: {name}")
# Normalize locator first so downstream calls target the correct script file. # Normalize locator first so downstream calls target the correct script file.
name, path = _normalize_script_locator(name, path) name, path = _normalize_script_locator(name, path)
# No NL path: clients must provide structured edits in 'edits'.
# Normalize unsupported or aliased ops to known structured/text paths # Normalize unsupported or aliased ops to known structured/text paths
def _unwrap_and_alias(edit: Dict[str, Any]) -> Dict[str, Any]:
def _unwrap_and_alias(edit: dict[str, Any]) -> dict[str, Any]:
# Unwrap single-key wrappers like {"replace_method": {...}} # Unwrap single-key wrappers like {"replace_method": {...}}
for wrapper_key in ( for wrapper_key in (
"replace_method","insert_method","delete_method", "replace_method", "insert_method", "delete_method",
"replace_class","delete_class", "replace_class", "delete_class",
"anchor_insert","anchor_replace","anchor_delete", "anchor_insert", "anchor_replace", "anchor_delete",
): ):
if wrapper_key in edit and isinstance(edit[wrapper_key], dict): if wrapper_key in edit and isinstance(edit[wrapper_key], dict):
inner = dict(edit[wrapper_key]) inner = dict(edit[wrapper_key])
@ -377,7 +387,8 @@ def register_manage_script_edits_tools(mcp: FastMCP):
break break
e = dict(edit) e = dict(edit)
op = (e.get("op") or e.get("operation") or e.get("type") or e.get("mode") or "").strip().lower() op = (e.get("op") or e.get("operation") or e.get(
"type") or e.get("mode") or "").strip().lower()
if op: if op:
e["op"] = op e["op"] = op
@ -452,13 +463,14 @@ def register_manage_script_edits_tools(mcp: FastMCP):
e["text"] = edit.get("newText", "") e["text"] = edit.get("newText", "")
return e return e
normalized_edits: List[Dict[str, Any]] = [] normalized_edits: list[dict[str, Any]] = []
for raw in edits or []: for raw in edits or []:
e = _unwrap_and_alias(raw) e = _unwrap_and_alias(raw)
op = (e.get("op") or e.get("operation") or e.get("type") or e.get("mode") or "").strip().lower() op = (e.get("op") or e.get("operation") or e.get(
"type") or e.get("mode") or "").strip().lower()
# Default className to script name if missing on structured method/class ops # Default className to script name if missing on structured method/class ops
if op in ("replace_class","delete_class","replace_method","delete_method","insert_method") and not e.get("className"): if op in ("replace_class", "delete_class", "replace_method", "delete_method", "insert_method") and not e.get("className"):
e["className"] = name e["className"] = name
# Map common aliases for text ops # Map common aliases for text ops
@ -475,7 +487,8 @@ def register_manage_script_edits_tools(mcp: FastMCP):
if "text" in e: if "text" in e:
e["replacement"] = e.get("text", "") e["replacement"] = e.get("text", "")
elif "insert" in e or "content" in e: elif "insert" in e or "content" in e:
e["replacement"] = e.get("insert") or e.get("content") or "" e["replacement"] = e.get(
"insert") or e.get("content") or ""
if op == "anchor_insert" and not (e.get("text") or e.get("insert") or e.get("content") or e.get("replacement")): if op == "anchor_insert" and not (e.get("text") or e.get("insert") or e.get("content") or e.get("replacement")):
e["op"] = "anchor_delete" e["op"] = "anchor_delete"
normalized_edits.append(e) normalized_edits.append(e)
@ -486,7 +499,7 @@ def register_manage_script_edits_tools(mcp: FastMCP):
normalized_for_echo = edits normalized_for_echo = edits
# Validate required fields and produce machine-parsable hints # Validate required fields and produce machine-parsable hints
def error_with_hint(message: str, expected: Dict[str, Any], suggestion: Dict[str, Any]) -> Dict[str, Any]: def error_with_hint(message: str, expected: dict[str, Any], suggestion: dict[str, Any]) -> dict[str, Any]:
return _err("missing_field", message, expected=expected, rewrite=suggestion, normalized=normalized_for_echo) return _err("missing_field", message, expected=expected, rewrite=suggestion, normalized=normalized_for_echo)
for e in edits or []: for e in edits or []:
@ -495,40 +508,46 @@ def register_manage_script_edits_tools(mcp: FastMCP):
if not e.get("methodName"): if not e.get("methodName"):
return error_with_hint( return error_with_hint(
"replace_method requires 'methodName'.", "replace_method requires 'methodName'.",
{"op": "replace_method", "required": ["className", "methodName", "replacement"]}, {"op": "replace_method", "required": [
"className", "methodName", "replacement"]},
{"edits[0].methodName": "HasTarget"} {"edits[0].methodName": "HasTarget"}
) )
if not (e.get("replacement") or e.get("text")): if not (e.get("replacement") or e.get("text")):
return error_with_hint( return error_with_hint(
"replace_method requires 'replacement' (inline or base64).", "replace_method requires 'replacement' (inline or base64).",
{"op": "replace_method", "required": ["className", "methodName", "replacement"]}, {"op": "replace_method", "required": [
"className", "methodName", "replacement"]},
{"edits[0].replacement": "public bool X(){ return true; }"} {"edits[0].replacement": "public bool X(){ return true; }"}
) )
elif op == "insert_method": elif op == "insert_method":
if not (e.get("replacement") or e.get("text")): if not (e.get("replacement") or e.get("text")):
return error_with_hint( return error_with_hint(
"insert_method requires a non-empty 'replacement'.", "insert_method requires a non-empty 'replacement'.",
{"op": "insert_method", "required": ["className", "replacement"], "position": {"after_requires": "afterMethodName", "before_requires": "beforeMethodName"}}, {"op": "insert_method", "required": ["className", "replacement"], "position": {
"after_requires": "afterMethodName", "before_requires": "beforeMethodName"}},
{"edits[0].replacement": "public void PrintSeries(){ Debug.Log(\"1,2,3\"); }"} {"edits[0].replacement": "public void PrintSeries(){ Debug.Log(\"1,2,3\"); }"}
) )
pos = (e.get("position") or "").lower() pos = (e.get("position") or "").lower()
if pos == "after" and not e.get("afterMethodName"): if pos == "after" and not e.get("afterMethodName"):
return error_with_hint( return error_with_hint(
"insert_method with position='after' requires 'afterMethodName'.", "insert_method with position='after' requires 'afterMethodName'.",
{"op": "insert_method", "position": {"after_requires": "afterMethodName"}}, {"op": "insert_method", "position": {
"after_requires": "afterMethodName"}},
{"edits[0].afterMethodName": "GetCurrentTarget"} {"edits[0].afterMethodName": "GetCurrentTarget"}
) )
if pos == "before" and not e.get("beforeMethodName"): if pos == "before" and not e.get("beforeMethodName"):
return error_with_hint( return error_with_hint(
"insert_method with position='before' requires 'beforeMethodName'.", "insert_method with position='before' requires 'beforeMethodName'.",
{"op": "insert_method", "position": {"before_requires": "beforeMethodName"}}, {"op": "insert_method", "position": {
"before_requires": "beforeMethodName"}},
{"edits[0].beforeMethodName": "GetCurrentTarget"} {"edits[0].beforeMethodName": "GetCurrentTarget"}
) )
elif op == "delete_method": elif op == "delete_method":
if not e.get("methodName"): if not e.get("methodName"):
return error_with_hint( return error_with_hint(
"delete_method requires 'methodName'.", "delete_method requires 'methodName'.",
{"op": "delete_method", "required": ["className", "methodName"]}, {"op": "delete_method", "required": [
"className", "methodName"]},
{"edits[0].methodName": "PrintSeries"} {"edits[0].methodName": "PrintSeries"}
) )
elif op in ("anchor_insert", "anchor_replace", "anchor_delete"): elif op in ("anchor_insert", "anchor_replace", "anchor_delete"):
@ -546,9 +565,10 @@ def register_manage_script_edits_tools(mcp: FastMCP):
) )
# Decide routing: structured vs text vs mixed # Decide routing: structured vs text vs mixed
STRUCT = {"replace_class","delete_class","replace_method","delete_method","insert_method","anchor_delete","anchor_replace","anchor_insert"} STRUCT = {"replace_class", "delete_class", "replace_method", "delete_method",
TEXT = {"prepend","append","replace_range","regex_replace"} "insert_method", "anchor_delete", "anchor_replace", "anchor_insert"}
ops_set = { (e.get("op") or "").lower() for e in edits or [] } TEXT = {"prepend", "append", "replace_range", "regex_replace"}
ops_set = {(e.get("op") or "").lower() for e in edits or []}
all_struct = ops_set.issubset(STRUCT) all_struct = ops_set.issubset(STRUCT)
all_text = ops_set.issubset(TEXT) all_text = ops_set.issubset(TEXT)
mixed = not (all_struct or all_text) mixed = not (all_struct or all_text)
@ -558,7 +578,7 @@ def register_manage_script_edits_tools(mcp: FastMCP):
opts2 = dict(options or {}) opts2 = dict(options or {})
# For structured edits, prefer immediate refresh to avoid missed reloads when Editor is unfocused # For structured edits, prefer immediate refresh to avoid missed reloads when Editor is unfocused
opts2.setdefault("refresh", "immediate") opts2.setdefault("refresh", "immediate")
params_struct: Dict[str, Any] = { params_struct: dict[str, Any] = {
"action": "edit", "action": "edit",
"name": name, "name": name,
"path": path, "path": path,
@ -567,7 +587,8 @@ def register_manage_script_edits_tools(mcp: FastMCP):
"edits": edits, "edits": edits,
"options": opts2, "options": opts2,
} }
resp_struct = send_command_with_retry("manage_script", params_struct) resp_struct = send_command_with_retry(
"manage_script", params_struct)
if isinstance(resp_struct, dict) and resp_struct.get("success"): if isinstance(resp_struct, dict) and resp_struct.get("success"):
pass # Optional sentinel reload removed (deprecated) pass # Optional sentinel reload removed (deprecated)
return _with_norm(resp_struct if isinstance(resp_struct, dict) else {"success": False, "message": str(resp_struct)}, normalized_for_echo, routing="structured") return _with_norm(resp_struct if isinstance(resp_struct, dict) else {"success": False, "message": str(resp_struct)}, normalized_for_echo, routing="structured")
@ -583,10 +604,12 @@ def register_manage_script_edits_tools(mcp: FastMCP):
if not isinstance(read_resp, dict) or not read_resp.get("success"): if not isinstance(read_resp, dict) or not read_resp.get("success"):
return read_resp if isinstance(read_resp, dict) else {"success": False, "message": str(read_resp)} return read_resp if isinstance(read_resp, dict) else {"success": False, "message": str(read_resp)}
data = read_resp.get("data") or read_resp.get("result", {}).get("data") or {} data = read_resp.get("data") or read_resp.get(
"result", {}).get("data") or {}
contents = data.get("contents") contents = data.get("contents")
if contents is None and data.get("contentsEncoded") and data.get("encodedContents"): if contents is None and data.get("contentsEncoded") and data.get("encodedContents"):
contents = base64.b64decode(data["encodedContents"]).decode("utf-8") contents = base64.b64decode(
data["encodedContents"]).decode("utf-8")
if contents is None: if contents is None:
return {"success": False, "message": "No contents returned from Unity read."} return {"success": False, "message": "No contents returned from Unity read."}
@ -595,28 +618,35 @@ def register_manage_script_edits_tools(mcp: FastMCP):
# If we have a mixed batch (TEXT + STRUCT), apply text first with precondition, then structured # If we have a mixed batch (TEXT + STRUCT), apply text first with precondition, then structured
if mixed: if mixed:
text_edits = [e for e in edits or [] if (e.get("op") or "").lower() in TEXT] text_edits = [e for e in edits or [] if (
struct_edits = [e for e in edits or [] if (e.get("op") or "").lower() in STRUCT] e.get("op") or "").lower() in TEXT]
struct_edits = [e for e in edits or [] if (
e.get("op") or "").lower() in STRUCT]
try: try:
base_text = contents base_text = contents
def line_col_from_index(idx: int) -> Tuple[int, int]:
def line_col_from_index(idx: int) -> tuple[int, int]:
line = base_text.count("\n", 0, idx) + 1 line = base_text.count("\n", 0, idx) + 1
last_nl = base_text.rfind("\n", 0, idx) last_nl = base_text.rfind("\n", 0, idx)
col = (idx - (last_nl + 1)) + 1 if last_nl >= 0 else idx + 1 col = (idx - (last_nl + 1)) + \
1 if last_nl >= 0 else idx + 1
return line, col return line, col
at_edits: List[Dict[str, Any]] = [] at_edits: list[dict[str, Any]] = []
import re as _re
for e in text_edits: for e in text_edits:
opx = (e.get("op") or e.get("operation") or e.get("type") or e.get("mode") or "").strip().lower() opx = (e.get("op") or e.get("operation") or e.get(
text_field = e.get("text") or e.get("insert") or e.get("content") or e.get("replacement") or "" "type") or e.get("mode") or "").strip().lower()
text_field = e.get("text") or e.get("insert") or e.get(
"content") or e.get("replacement") or ""
if opx == "anchor_insert": if opx == "anchor_insert":
anchor = e.get("anchor") or "" anchor = e.get("anchor") or ""
position = (e.get("position") or "after").lower() position = (e.get("position") or "after").lower()
flags = _re.MULTILINE | (_re.IGNORECASE if e.get("ignore_case") else 0) flags = re.MULTILINE | (
re.IGNORECASE if e.get("ignore_case") else 0)
try: try:
# Use improved anchor matching logic # Use improved anchor matching logic
m = _find_best_anchor_match(anchor, base_text, flags, prefer_last=True) m = _find_best_anchor_match(
anchor, base_text, flags, prefer_last=True)
except Exception as ex: except Exception as ex:
return _with_norm(_err("bad_regex", f"Invalid anchor regex: {ex}", normalized=normalized_for_echo, routing="mixed/text-first", extra={"hint": "Escape parentheses/braces or use a simpler anchor."}), normalized_for_echo, routing="mixed/text-first") return _with_norm(_err("bad_regex", f"Invalid anchor regex: {ex}", normalized=normalized_for_echo, routing="mixed/text-first", extra={"hint": "Escape parentheses/braces or use a simpler anchor."}), normalized_for_echo, routing="mixed/text-first")
if not m: if not m:
@ -629,10 +659,11 @@ def register_manage_script_edits_tools(mcp: FastMCP):
if not text_field_norm.endswith("\n"): if not text_field_norm.endswith("\n"):
text_field_norm = text_field_norm + "\n" text_field_norm = text_field_norm + "\n"
sl, sc = line_col_from_index(idx) sl, sc = line_col_from_index(idx)
at_edits.append({"startLine": sl, "startCol": sc, "endLine": sl, "endCol": sc, "newText": text_field_norm}) at_edits.append(
{"startLine": sl, "startCol": sc, "endLine": sl, "endCol": sc, "newText": text_field_norm})
# do not mutate base_text when building atomic spans # do not mutate base_text when building atomic spans
elif opx == "replace_range": elif opx == "replace_range":
if all(k in e for k in ("startLine","startCol","endLine","endCol")): if all(k in e for k in ("startLine", "startCol", "endLine", "endCol")):
at_edits.append({ at_edits.append({
"startLine": int(e.get("startLine", 1)), "startLine": int(e.get("startLine", 1)),
"startCol": int(e.get("startCol", 1)), "startCol": int(e.get("startCol", 1)),
@ -645,39 +676,44 @@ def register_manage_script_edits_tools(mcp: FastMCP):
elif opx == "regex_replace": elif opx == "regex_replace":
pattern = e.get("pattern") or "" pattern = e.get("pattern") or ""
try: try:
regex_obj = _re.compile(pattern, _re.MULTILINE | (_re.IGNORECASE if e.get("ignore_case") else 0)) regex_obj = re.compile(pattern, re.MULTILINE | (
re.IGNORECASE if e.get("ignore_case") else 0))
except Exception as ex: except Exception as ex:
return _with_norm(_err("bad_regex", f"Invalid regex pattern: {ex}", normalized=normalized_for_echo, routing="mixed/text-first", extra={"hint": "Escape special chars or prefer structured delete for methods."}), normalized_for_echo, routing="mixed/text-first") return _with_norm(_err("bad_regex", f"Invalid regex pattern: {ex}", normalized=normalized_for_echo, routing="mixed/text-first", extra={"hint": "Escape special chars or prefer structured delete for methods."}), normalized_for_echo, routing="mixed/text-first")
m = regex_obj.search(base_text) m = regex_obj.search(base_text)
if not m: if not m:
continue continue
# Expand $1, $2... in replacement using this match # Expand $1, $2... in replacement using this match
def _expand_dollars(rep: str, _m=m) -> str: def _expand_dollars(rep: str, _m=m) -> str:
return _re.sub(r"\$(\d+)", lambda g: _m.group(int(g.group(1))) or "", rep) return re.sub(r"\$(\d+)", lambda g: _m.group(int(g.group(1))) or "", rep)
repl = _expand_dollars(text_field) repl = _expand_dollars(text_field)
sl, sc = line_col_from_index(m.start()) sl, sc = line_col_from_index(m.start())
el, ec = line_col_from_index(m.end()) el, ec = line_col_from_index(m.end())
at_edits.append({"startLine": sl, "startCol": sc, "endLine": el, "endCol": ec, "newText": repl}) at_edits.append(
{"startLine": sl, "startCol": sc, "endLine": el, "endCol": ec, "newText": repl})
# do not mutate base_text when building atomic spans # do not mutate base_text when building atomic spans
elif opx in ("prepend","append"): elif opx in ("prepend", "append"):
if opx == "prepend": if opx == "prepend":
sl, sc = 1, 1 sl, sc = 1, 1
at_edits.append({"startLine": sl, "startCol": sc, "endLine": sl, "endCol": sc, "newText": text_field}) at_edits.append(
{"startLine": sl, "startCol": sc, "endLine": sl, "endCol": sc, "newText": text_field})
# prepend can be applied atomically without local mutation # prepend can be applied atomically without local mutation
else: else:
# Insert at true EOF position (handles both \n and \r\n correctly) # Insert at true EOF position (handles both \n and \r\n correctly)
eof_idx = len(base_text) eof_idx = len(base_text)
sl, sc = line_col_from_index(eof_idx) sl, sc = line_col_from_index(eof_idx)
new_text = ("\n" if not base_text.endswith("\n") else "") + text_field new_text = ("\n" if not base_text.endswith(
at_edits.append({"startLine": sl, "startCol": sc, "endLine": sl, "endCol": sc, "newText": new_text}) "\n") else "") + text_field
at_edits.append(
{"startLine": sl, "startCol": sc, "endLine": sl, "endCol": sc, "newText": new_text})
# do not mutate base_text when building atomic spans # do not mutate base_text when building atomic spans
else: else:
return _with_norm(_err("unknown_op", f"Unsupported text edit op: {opx}", normalized=normalized_for_echo, routing="mixed/text-first"), normalized_for_echo, routing="mixed/text-first") return _with_norm(_err("unknown_op", f"Unsupported text edit op: {opx}", normalized=normalized_for_echo, routing="mixed/text-first"), normalized_for_echo, routing="mixed/text-first")
import hashlib
sha = hashlib.sha256(base_text.encode("utf-8")).hexdigest() sha = hashlib.sha256(base_text.encode("utf-8")).hexdigest()
if at_edits: if at_edits:
params_text: Dict[str, Any] = { params_text: dict[str, Any] = {
"action": "apply_text_edits", "action": "apply_text_edits",
"name": name, "name": name,
"path": path, "path": path,
@ -687,7 +723,8 @@ def register_manage_script_edits_tools(mcp: FastMCP):
"precondition_sha256": sha, "precondition_sha256": sha,
"options": {"refresh": (options or {}).get("refresh", "debounced"), "validate": (options or {}).get("validate", "standard"), "applyMode": ("atomic" if len(at_edits) > 1 else (options or {}).get("applyMode", "sequential"))} "options": {"refresh": (options or {}).get("refresh", "debounced"), "validate": (options or {}).get("validate", "standard"), "applyMode": ("atomic" if len(at_edits) > 1 else (options or {}).get("applyMode", "sequential"))}
} }
resp_text = send_command_with_retry("manage_script", params_text) resp_text = send_command_with_retry(
"manage_script", params_text)
if not (isinstance(resp_text, dict) and resp_text.get("success")): if not (isinstance(resp_text, dict) and resp_text.get("success")):
return _with_norm(resp_text if isinstance(resp_text, dict) else {"success": False, "message": str(resp_text)}, normalized_for_echo, routing="mixed/text-first") return _with_norm(resp_text if isinstance(resp_text, dict) else {"success": False, "message": str(resp_text)}, normalized_for_echo, routing="mixed/text-first")
# Optional sentinel reload removed (deprecated) # Optional sentinel reload removed (deprecated)
@ -698,7 +735,7 @@ def register_manage_script_edits_tools(mcp: FastMCP):
opts2 = dict(options or {}) opts2 = dict(options or {})
# Prefer debounced background refresh unless explicitly overridden # Prefer debounced background refresh unless explicitly overridden
opts2.setdefault("refresh", "debounced") opts2.setdefault("refresh", "debounced")
params_struct: Dict[str, Any] = { params_struct: dict[str, Any] = {
"action": "edit", "action": "edit",
"name": name, "name": name,
"path": path, "path": path,
@ -707,7 +744,8 @@ def register_manage_script_edits_tools(mcp: FastMCP):
"edits": struct_edits, "edits": struct_edits,
"options": opts2 "options": opts2
} }
resp_struct = send_command_with_retry("manage_script", params_struct) resp_struct = send_command_with_retry(
"manage_script", params_struct)
if isinstance(resp_struct, dict) and resp_struct.get("success"): if isinstance(resp_struct, dict) and resp_struct.get("success"):
pass # Optional sentinel reload removed (deprecated) pass # Optional sentinel reload removed (deprecated)
return _with_norm(resp_struct if isinstance(resp_struct, dict) else {"success": False, "message": str(resp_struct)}, normalized_for_echo, routing="mixed/text-first") return _with_norm(resp_struct if isinstance(resp_struct, dict) else {"success": False, "message": str(resp_struct)}, normalized_for_echo, routing="mixed/text-first")
@ -717,32 +755,40 @@ def register_manage_script_edits_tools(mcp: FastMCP):
# If the edits are text-ops, prefer sending them to Unity's apply_text_edits with precondition # If the edits are text-ops, prefer sending them to Unity's apply_text_edits with precondition
# so header guards and validation run on the C# side. # so header guards and validation run on the C# side.
# Supported conversions: anchor_insert, replace_range, regex_replace (first match only). # Supported conversions: anchor_insert, replace_range, regex_replace (first match only).
text_ops = { (e.get("op") or e.get("operation") or e.get("type") or e.get("mode") or "").strip().lower() for e in (edits or []) } text_ops = {(e.get("op") or e.get("operation") or e.get("type") or e.get(
structured_kinds = {"replace_class","delete_class","replace_method","delete_method","insert_method","anchor_insert"} "mode") or "").strip().lower() for e in (edits or [])}
structured_kinds = {"replace_class", "delete_class",
"replace_method", "delete_method", "insert_method", "anchor_insert"}
if not text_ops.issubset(structured_kinds): if not text_ops.issubset(structured_kinds):
# Convert to apply_text_edits payload # Convert to apply_text_edits payload
try: try:
base_text = contents base_text = contents
def line_col_from_index(idx: int) -> Tuple[int, int]:
def line_col_from_index(idx: int) -> tuple[int, int]:
# 1-based line/col against base buffer # 1-based line/col against base buffer
line = base_text.count("\n", 0, idx) + 1 line = base_text.count("\n", 0, idx) + 1
last_nl = base_text.rfind("\n", 0, idx) last_nl = base_text.rfind("\n", 0, idx)
col = (idx - (last_nl + 1)) + 1 if last_nl >= 0 else idx + 1 col = (idx - (last_nl + 1)) + \
1 if last_nl >= 0 else idx + 1
return line, col return line, col
at_edits: List[Dict[str, Any]] = [] at_edits: list[dict[str, Any]] = []
import re as _re import re as _re
for e in edits or []: for e in edits or []:
op = (e.get("op") or e.get("operation") or e.get("type") or e.get("mode") or "").strip().lower() op = (e.get("op") or e.get("operation") or e.get(
"type") or e.get("mode") or "").strip().lower()
# aliasing for text field # aliasing for text field
text_field = e.get("text") or e.get("insert") or e.get("content") or "" text_field = e.get("text") or e.get(
"insert") or e.get("content") or ""
if op == "anchor_insert": if op == "anchor_insert":
anchor = e.get("anchor") or "" anchor = e.get("anchor") or ""
position = (e.get("position") or "after").lower() position = (e.get("position") or "after").lower()
# Use improved anchor matching logic with helpful errors, honoring ignore_case # Use improved anchor matching logic with helpful errors, honoring ignore_case
try: try:
flags = _re.MULTILINE | (_re.IGNORECASE if e.get("ignore_case") else 0) flags = re.MULTILINE | (
m = _find_best_anchor_match(anchor, base_text, flags, prefer_last=True) re.IGNORECASE if e.get("ignore_case") else 0)
m = _find_best_anchor_match(
anchor, base_text, flags, prefer_last=True)
except Exception as ex: except Exception as ex:
return _with_norm(_err("bad_regex", f"Invalid anchor regex: {ex}", normalized=normalized_for_echo, routing="text", extra={"hint": "Escape parentheses/braces or use a simpler anchor."}), normalized_for_echo, routing="text") return _with_norm(_err("bad_regex", f"Invalid anchor regex: {ex}", normalized=normalized_for_echo, routing="text", extra={"hint": "Escape parentheses/braces or use a simpler anchor."}), normalized_for_echo, routing="text")
if not m: if not m:
@ -778,19 +824,22 @@ def register_manage_script_edits_tools(mcp: FastMCP):
elif op == "regex_replace": elif op == "regex_replace":
pattern = e.get("pattern") or "" pattern = e.get("pattern") or ""
repl = text_field repl = text_field
flags = _re.MULTILINE | (_re.IGNORECASE if e.get("ignore_case") else 0) flags = re.MULTILINE | (
re.IGNORECASE if e.get("ignore_case") else 0)
# Early compile for clearer error messages # Early compile for clearer error messages
try: try:
regex_obj = _re.compile(pattern, flags) regex_obj = re.compile(pattern, flags)
except Exception as ex: except Exception as ex:
return _with_norm(_err("bad_regex", f"Invalid regex pattern: {ex}", normalized=normalized_for_echo, routing="text", extra={"hint": "Escape special chars or prefer structured delete for methods."}), normalized_for_echo, routing="text") return _with_norm(_err("bad_regex", f"Invalid regex pattern: {ex}", normalized=normalized_for_echo, routing="text", extra={"hint": "Escape special chars or prefer structured delete for methods."}), normalized_for_echo, routing="text")
# Use smart anchor matching for consistent behavior with anchor_insert # Use smart anchor matching for consistent behavior with anchor_insert
m = _find_best_anchor_match(pattern, base_text, flags, prefer_last=True) m = _find_best_anchor_match(
pattern, base_text, flags, prefer_last=True)
if not m: if not m:
continue continue
# Expand $1, $2... backrefs in replacement using the first match (consistent with mixed-path behavior) # Expand $1, $2... backrefs in replacement using the first match (consistent with mixed-path behavior)
def _expand_dollars(rep: str, _m=m) -> str: def _expand_dollars(rep: str, _m=m) -> str:
return _re.sub(r"\$(\d+)", lambda g: _m.group(int(g.group(1))) or "", rep) return re.sub(r"\$(\d+)", lambda g: _m.group(int(g.group(1))) or "", rep)
repl_expanded = _expand_dollars(repl) repl_expanded = _expand_dollars(repl)
# Let C# side handle validation using Unity's built-in compiler services # Let C# side handle validation using Unity's built-in compiler services
sl, sc = line_col_from_index(m.start()) sl, sc = line_col_from_index(m.start())
@ -809,10 +858,8 @@ def register_manage_script_edits_tools(mcp: FastMCP):
if not at_edits: if not at_edits:
return _with_norm({"success": False, "code": "no_spans", "message": "No applicable text edit spans computed (anchor not found or zero-length)."}, normalized_for_echo, routing="text") return _with_norm({"success": False, "code": "no_spans", "message": "No applicable text edit spans computed (anchor not found or zero-length)."}, normalized_for_echo, routing="text")
# Send to Unity with precondition SHA to enforce guards and immediate refresh
import hashlib
sha = hashlib.sha256(base_text.encode("utf-8")).hexdigest() sha = hashlib.sha256(base_text.encode("utf-8")).hexdigest()
params: Dict[str, Any] = { params: dict[str, Any] = {
"action": "apply_text_edits", "action": "apply_text_edits",
"name": name, "name": name,
"path": path, "path": path,
@ -830,7 +877,8 @@ def register_manage_script_edits_tools(mcp: FastMCP):
if isinstance(resp, dict) and resp.get("success"): if isinstance(resp, dict) and resp.get("success"):
pass # Optional sentinel reload removed (deprecated) pass # Optional sentinel reload removed (deprecated)
return _with_norm( return _with_norm(
resp if isinstance(resp, dict) else {"success": False, "message": str(resp)}, resp if isinstance(resp, dict) else {
"success": False, "message": str(resp)},
normalized_for_echo, normalized_for_echo,
routing="text" routing="text"
) )
@ -843,7 +891,8 @@ def register_manage_script_edits_tools(mcp: FastMCP):
try: try:
preview_text = _apply_edits_locally(contents, edits) preview_text = _apply_edits_locally(contents, edits)
import difflib import difflib
diff = list(difflib.unified_diff(contents.splitlines(), preview_text.splitlines(), fromfile="before", tofile="after", n=2)) diff = list(difflib.unified_diff(contents.splitlines(
), preview_text.splitlines(), fromfile="before", tofile="after", n=2))
if len(diff) > 800: if len(diff) > 800:
diff = diff[:800] + ["... (diff truncated) ..."] diff = diff[:800] + ["... (diff truncated) ..."]
if preview: if preview:
@ -870,7 +919,8 @@ def register_manage_script_edits_tools(mcp: FastMCP):
import difflib import difflib
a = contents.splitlines() a = contents.splitlines()
b = new_contents.splitlines() b = new_contents.splitlines()
diff = list(difflib.unified_diff(a, b, fromfile="before", tofile="after", n=3)) diff = list(difflib.unified_diff(
a, b, fromfile="before", tofile="after", n=3))
# Limit diff size to keep responses small # Limit diff size to keep responses small
if len(diff) > 2000: if len(diff) > 2000:
diff = diff[:2000] + ["... (diff truncated) ..."] diff = diff[:2000] + ["... (diff truncated) ..."]
@ -882,7 +932,6 @@ def register_manage_script_edits_tools(mcp: FastMCP):
options.setdefault("validate", "standard") options.setdefault("validate", "standard")
options.setdefault("refresh", "debounced") options.setdefault("refresh", "debounced")
import hashlib
# Compute the SHA of the current file contents for the precondition # Compute the SHA of the current file contents for the precondition
old_lines = contents.splitlines(keepends=True) old_lines = contents.splitlines(keepends=True)
end_line = len(old_lines) + 1 # 1-based exclusive end end_line = len(old_lines) + 1 # 1-based exclusive end
@ -912,13 +961,8 @@ def register_manage_script_edits_tools(mcp: FastMCP):
if isinstance(write_resp, dict) and write_resp.get("success"): if isinstance(write_resp, dict) and write_resp.get("success"):
pass # Optional sentinel reload removed (deprecated) pass # Optional sentinel reload removed (deprecated)
return _with_norm( return _with_norm(
write_resp if isinstance(write_resp, dict) write_resp if isinstance(write_resp, dict)
else {"success": False, "message": str(write_resp)}, else {"success": False, "message": str(write_resp)},
normalized_for_echo, normalized_for_echo,
routing="text", routing="text",
) )
# safe_script_edit removed to simplify API; clients should call script_apply_edits directly

View File

@ -1,36 +1,26 @@
from mcp.server.fastmcp import FastMCP, Context
from typing import Dict, Any
from unity_connection import get_unity_connection, send_command_with_retry
from config import config
import time
import os
import base64 import base64
from typing import Annotated, Any, Literal
from mcp.server.fastmcp import FastMCP, Context
from telemetry_decorator import telemetry_tool from telemetry_decorator import telemetry_tool
from unity_connection import send_command_with_retry
def register_manage_shader_tools(mcp: FastMCP): def register_manage_shader_tools(mcp: FastMCP):
"""Register all shader script management tools with the MCP server.""" """Register all shader script management tools with the MCP server."""
@mcp.tool() @mcp.tool(name="manage_shader", description="Manages shader scripts in Unity (create, read, update, delete).")
@telemetry_tool("manage_shader") @telemetry_tool("manage_shader")
def manage_shader( def manage_shader(
ctx: Any, ctx: Context,
action: str, action: Annotated[Literal['create', 'read', 'update', 'delete'], "Perform CRUD operations on shader scripts."],
name: str, name: Annotated[str, "Shader name (no .cs extension)"],
path: str, path: Annotated[str, "Asset path (default: \"Assets/\")"],
contents: str, contents: Annotated[str,
) -> Dict[str, Any]: "Shader code for 'create'/'update'"] | None = None,
"""Manages shader scripts in Unity (create, read, update, delete). ) -> dict[str, Any]:
ctx.info(f"Processing manage_shader: {action}")
Args:
action: Operation ('create', 'read', 'update', 'delete').
name: Shader name (no .cs extension).
path: Asset path (default: "Assets/").
contents: Shader code for 'create'/'update'.
Returns:
Dictionary with results ('success', 'message', 'data').
"""
try: try:
# Prepare parameters for Unity # Prepare parameters for Unity
params = { params = {
@ -38,34 +28,36 @@ def register_manage_shader_tools(mcp: FastMCP):
"name": name, "name": name,
"path": path, "path": path,
} }
# Base64 encode the contents if they exist to avoid JSON escaping issues # Base64 encode the contents if they exist to avoid JSON escaping issues
if contents is not None: if contents is not None:
if action in ['create', 'update']: if action in ['create', 'update']:
# Encode content for safer transmission # Encode content for safer transmission
params["encodedContents"] = base64.b64encode(contents.encode('utf-8')).decode('utf-8') params["encodedContents"] = base64.b64encode(
contents.encode('utf-8')).decode('utf-8')
params["contentsEncoded"] = True params["contentsEncoded"] = True
else: else:
params["contents"] = contents params["contents"] = contents
# Remove None values so they don't get sent as null # Remove None values so they don't get sent as null
params = {k: v for k, v in params.items() if v is not None} params = {k: v for k, v in params.items() if v is not None}
# Send command via centralized retry helper # Send command via centralized retry helper
response = send_command_with_retry("manage_shader", params) response = send_command_with_retry("manage_shader", params)
# Process response from Unity # Process response from Unity
if isinstance(response, dict) and response.get("success"): if isinstance(response, dict) and response.get("success"):
# If the response contains base64 encoded content, decode it # If the response contains base64 encoded content, decode it
if response.get("data", {}).get("contentsEncoded"): if response.get("data", {}).get("contentsEncoded"):
decoded_contents = base64.b64decode(response["data"]["encodedContents"]).decode('utf-8') decoded_contents = base64.b64decode(
response["data"]["encodedContents"]).decode('utf-8')
response["data"]["contents"] = decoded_contents response["data"]["contents"] = decoded_contents
del response["data"]["encodedContents"] del response["data"]["encodedContents"]
del response["data"]["contentsEncoded"] del response["data"]["contentsEncoded"]
return {"success": True, "message": response.get("message", "Operation successful."), "data": response.get("data")} return {"success": True, "message": response.get("message", "Operation successful."), "data": response.get("data")}
return response if isinstance(response, dict) else {"success": False, "message": str(response)} return response if isinstance(response, dict) else {"success": False, "message": str(response)}
except Exception as e: except Exception as e:
# Handle Python-side errors (e.g., connection issues) # Handle Python-side errors (e.g., connection issues)
return {"success": False, "message": f"Python error managing shader: {str(e)}"} return {"success": False, "message": f"Python error managing shader: {str(e)}"}

View File

@ -1,47 +1,34 @@
""" """
Defines the read_console tool for accessing Unity Editor console messages. Defines the read_console tool for accessing Unity Editor console messages.
""" """
from typing import List, Dict, Any from typing import Annotated, Any, Literal
import time
from mcp.server.fastmcp import FastMCP, Context from mcp.server.fastmcp import FastMCP, Context
from unity_connection import get_unity_connection, send_command_with_retry
from config import config
from telemetry_decorator import telemetry_tool from telemetry_decorator import telemetry_tool
from unity_connection import send_command_with_retry
def register_read_console_tools(mcp: FastMCP): def register_read_console_tools(mcp: FastMCP):
"""Registers the read_console tool with the MCP server.""" """Registers the read_console tool with the MCP server."""
@mcp.tool() @mcp.tool(name="read_console", description="Gets messages from or clears the Unity Editor console.")
@telemetry_tool("read_console") @telemetry_tool("read_console")
def read_console( def read_console(
ctx: Context, ctx: Context,
action: str = None, action: Annotated[Literal['get', 'clear'], "Get or clear the Unity Editor console."],
types: List[str] = None, types: Annotated[list[Literal['error', 'warning',
count: Any = None, 'log', 'all']], "Message types to get"] | None = None,
filter_text: str = None, count: Annotated[int, "Max messages to return"] | None = None,
since_timestamp: str = None, filter_text: Annotated[str, "Text filter for messages"] | None = None,
format: str = None, since_timestamp: Annotated[str,
include_stacktrace: bool = None "Get messages after this timestamp (ISO 8601)"] | None = None,
) -> Dict[str, Any]: format: Annotated[Literal['plain', 'detailed',
"""Gets messages from or clears the Unity Editor console. 'json'], "Output format"] | None = None,
include_stacktrace: Annotated[bool,
Args: "Include stack traces in output"] | None = None
ctx: The MCP context. ) -> dict[str, Any]:
action: Operation ('get' or 'clear'). ctx.info(f"Processing read_console: {action}")
types: Message types to get ('error', 'warning', 'log', 'all').
count: Max messages to return.
filter_text: Text filter for messages.
since_timestamp: Get messages after this timestamp (ISO 8601).
format: Output format ('plain', 'detailed', 'json').
include_stacktrace: Include stack traces in output.
Returns:
Dictionary with results. For 'get', includes 'data' (messages).
"""
# Get the connection instance
bridge = get_unity_connection()
# Set defaults if values are None # Set defaults if values are None
action = action if action is not None else 'get' action = action if action is not None else 'get'
types = types if types is not None else ['error', 'warning', 'log'] types = types if types is not None else ['error', 'warning', 'log']
@ -51,7 +38,7 @@ def register_read_console_tools(mcp: FastMCP):
# Normalize action if it's a string # Normalize action if it's a string
if isinstance(action, str): if isinstance(action, str):
action = action.lower() action = action.lower()
# Coerce count defensively (string/float -> int) # Coerce count defensively (string/float -> int)
def _coerce_int(value, default=None): def _coerce_int(value, default=None):
if value is None: if value is None:
@ -82,11 +69,12 @@ def register_read_console_tools(mcp: FastMCP):
} }
# Remove None values unless it's 'count' (as None might mean 'all') # Remove None values unless it's 'count' (as None might mean 'all')
params_dict = {k: v for k, v in params_dict.items() if v is not None or k == 'count'} params_dict = {k: v for k, v in params_dict.items()
if v is not None or k == 'count'}
# Add count back if it was None, explicitly sending null might be important for C# logic # Add count back if it was None, explicitly sending null might be important for C# logic
if 'count' not in params_dict: if 'count' not in params_dict:
params_dict['count'] = None params_dict['count'] = None
# Use centralized retry helper # Use centralized retry helper
resp = send_command_with_retry("read_console", params_dict) resp = send_command_with_retry("read_console", params_dict)
@ -99,4 +87,4 @@ def register_read_console_tools(mcp: FastMCP):
line.pop("stacktrace", None) line.pop("stacktrace", None)
except Exception: except Exception:
pass pass
return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)} return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)}

View File

@ -3,21 +3,21 @@ Resource wrapper tools so clients that do not expose MCP resources primitives
can still list and read files via normal tools. These call into the same can still list and read files via normal tools. These call into the same
safe path logic (re-implemented here to avoid importing server.py). safe path logic (re-implemented here to avoid importing server.py).
""" """
from typing import Dict, Any, List, Optional
import re
from pathlib import Path
from urllib.parse import urlparse, unquote
import fnmatch import fnmatch
import hashlib import hashlib
import os import os
from pathlib import Path
import re
from typing import Annotated, Any
from urllib.parse import urlparse, unquote
from mcp.server.fastmcp import FastMCP, Context from mcp.server.fastmcp import FastMCP, Context
from telemetry_decorator import telemetry_tool from telemetry_decorator import telemetry_tool
from unity_connection import send_command_with_retry from unity_connection import send_command_with_retry
def _coerce_int(value: Any, default: Optional[int] = None, minimum: Optional[int] = None) -> Optional[int]: def _coerce_int(value: Any, default: int | None = None, minimum: int | None = None) -> int | None:
"""Safely coerce various inputs (str/float/etc.) to an int. """Safely coerce various inputs (str/float/etc.) to an int.
Returns default on failure; clamps to minimum when provided. Returns default on failure; clamps to minimum when provided.
""" """
@ -41,6 +41,7 @@ def _coerce_int(value: Any, default: Optional[int] = None, minimum: Optional[int
except Exception: except Exception:
return default return default
def _resolve_project_root(override: str | None) -> Path: def _resolve_project_root(override: str | None) -> Path:
# 1) Explicit override # 1) Explicit override
if override: if override:
@ -52,14 +53,17 @@ def _resolve_project_root(override: str | None) -> Path:
if env: if env:
env_path = Path(env).expanduser() env_path = Path(env).expanduser()
# If UNITY_PROJECT_ROOT is relative, resolve against repo root (cwd's repo) instead of src dir # If UNITY_PROJECT_ROOT is relative, resolve against repo root (cwd's repo) instead of src dir
pr = (Path.cwd() / env_path).resolve() if not env_path.is_absolute() else env_path.resolve() pr = (Path.cwd(
) / env_path).resolve() if not env_path.is_absolute() else env_path.resolve()
if (pr / "Assets").exists(): if (pr / "Assets").exists():
return pr return pr
# 3) Ask Unity via manage_editor.get_project_root # 3) Ask Unity via manage_editor.get_project_root
try: try:
resp = send_command_with_retry("manage_editor", {"action": "get_project_root"}) resp = send_command_with_retry(
"manage_editor", {"action": "get_project_root"})
if isinstance(resp, dict) and resp.get("success"): if isinstance(resp, dict) and resp.get("success"):
pr = Path(resp.get("data", {}).get("projectRoot", "")).expanduser().resolve() pr = Path(resp.get("data", {}).get(
"projectRoot", "")).expanduser().resolve()
if pr and (pr / "Assets").exists(): if pr and (pr / "Assets").exists():
return pr return pr
except Exception: except Exception:
@ -132,26 +136,17 @@ def _resolve_safe_path_from_uri(uri: str, project: Path) -> Path | None:
def register_resource_tools(mcp: FastMCP) -> None: def register_resource_tools(mcp: FastMCP) -> None:
"""Registers list_resources and read_resource wrapper tools.""" """Registers list_resources and read_resource wrapper tools."""
@mcp.tool(description=( @mcp.tool(name="list_resources", description=("List project URIs (unity://path/...) under a folder (default: Assets). Only .cs files are returned by default; always appends unity://spec/script-edits.\n"))
"List project URIs (unity://path/...) under a folder (default: Assets).\n\n"
"Args: pattern (glob, default *.cs), under (folder under project root), limit, project_root.\n"
"Security: restricted to Assets/ subtree; symlinks are resolved and must remain under Assets/.\n"
"Notes: Only .cs files are returned by default; always appends unity://spec/script-edits.\n"
))
@telemetry_tool("list_resources") @telemetry_tool("list_resources")
async def list_resources( async def list_resources(
ctx: Optional[Context] = None, ctx: Context,
pattern: Optional[str] = "*.cs", pattern: Annotated[str, "Glob, default is *.cs"] | None = "*.cs",
under: str = "Assets", under: Annotated[str,
limit: Any = 200, "Folder under project root, default is Assets"] = "Assets",
project_root: Optional[str] = None, limit: Annotated[int, "Page limit"] = 200,
) -> Dict[str, Any]: project_root: Annotated[str, "Project path"] | None = None,
""" ) -> dict[str, Any]:
Lists project URIs (unity://path/...) under a folder (default: Assets). ctx.info(f"Processing list_resources: {pattern}")
- pattern: glob like *.cs or *.shader (None to list all files)
- under: relative folder under project root
- limit: max results
"""
try: try:
project = _resolve_project_root(project_root) project = _resolve_project_root(project_root)
base = (project / under).resolve() base = (project / under).resolve()
@ -165,7 +160,7 @@ def register_resource_tools(mcp: FastMCP) -> None:
except ValueError: except ValueError:
return {"success": False, "error": "Listing is restricted to Assets/"} return {"success": False, "error": "Listing is restricted to Assets/"}
matches: List[str] = [] matches: list[str] = []
limit_int = _coerce_int(limit, default=200, minimum=1) limit_int = _coerce_int(limit, default=200, minimum=1)
for p in base.rglob("*"): for p in base.rglob("*"):
if not p.is_file(): if not p.is_file():
@ -194,33 +189,30 @@ def register_resource_tools(mcp: FastMCP) -> None:
except Exception as e: except Exception as e:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@mcp.tool(description=( @mcp.tool(name="read_resource", description=("Reads a resource by unity://path/... URI with optional slicing."))
"Read a resource by unity://path/... URI with optional slicing.\n\n"
"Args: uri, start_line/line_count or head_bytes, tail_lines (optional), project_root, request (NL hints).\n"
"Security: uri must resolve under Assets/.\n"
"Examples: head_bytes=1024; start_line=100,line_count=40; tail_lines=120.\n"
))
@telemetry_tool("read_resource") @telemetry_tool("read_resource")
async def read_resource( async def read_resource(
uri: str, ctx: Context,
ctx: Optional[Context] = None, uri: Annotated[str, "The resource URI to read under Assets/"],
start_line: Any = None, start_line: Annotated[int,
line_count: Any = None, "The starting line number (0-based)"] | None = None,
head_bytes: Any = None, line_count: Annotated[int,
tail_lines: Any = None, "The number of lines to read"] | None = None,
project_root: Optional[str] = None, head_bytes: Annotated[int,
request: Optional[str] = None, "The number of bytes to read from the start of the file"] | None = None,
) -> Dict[str, Any]: tail_lines: Annotated[int,
""" "The number of lines to read from the end of the file"] | None = None,
Reads a resource by unity://path/... URI with optional slicing. project_root: Annotated[str,
One of line window (start_line/line_count) or head_bytes can be used to limit size. "The project root directory"] | None = None,
""" request: Annotated[str, "The request ID"] | None = None,
) -> dict[str, Any]:
ctx.info(f"Processing read_resource: {uri}")
try: try:
# Serve the canonical spec directly when requested (allow bare or with scheme) # Serve the canonical spec directly when requested (allow bare or with scheme)
if uri in ("unity://spec/script-edits", "spec/script-edits", "script-edits"): if uri in ("unity://spec/script-edits", "spec/script-edits", "script-edits"):
spec_json = ( spec_json = (
'{\n' '{\n'
' "name": "Unity MCP Script Edits v1",\n' ' "name": "Unity MCP - Script Edits v1",\n'
' "target_tool": "script_apply_edits",\n' ' "target_tool": "script_apply_edits",\n'
' "canonical_rules": {\n' ' "canonical_rules": {\n'
' "always_use": ["op","className","methodName","replacement","afterMethodName","beforeMethodName"],\n' ' "always_use": ["op","className","methodName","replacement","afterMethodName","beforeMethodName"],\n'
@ -300,14 +292,16 @@ def register_resource_tools(mcp: FastMCP) -> None:
m = re.search(r"first\s+(\d+)\s*bytes", req) m = re.search(r"first\s+(\d+)\s*bytes", req)
if m: if m:
head_bytes = int(m.group(1)) head_bytes = int(m.group(1))
m = re.search(r"show\s+(\d+)\s+lines\s+around\s+([A-Za-z_][A-Za-z0-9_]*)", req) m = re.search(
r"show\s+(\d+)\s+lines\s+around\s+([A-Za-z_][A-Za-z0-9_]*)", req)
if m: if m:
window = int(m.group(1)) window = int(m.group(1))
method = m.group(2) method = m.group(2)
# naive search for method header to get a line number # naive search for method header to get a line number
text_all = p.read_text(encoding="utf-8") text_all = p.read_text(encoding="utf-8")
lines_all = text_all.splitlines() lines_all = text_all.splitlines()
pat = re.compile(rf"^\s*(?:\[[^\]]+\]\s*)*(?:public|private|protected|internal|static|virtual|override|sealed|async|extern|unsafe|new|partial).*?\b{re.escape(method)}\s*\(", re.MULTILINE) pat = re.compile(
rf"^\s*(?:\[[^\]]+\]\s*)*(?:public|private|protected|internal|static|virtual|override|sealed|async|extern|unsafe|new|partial).*?\b{re.escape(method)}\s*\(", re.MULTILINE)
hit_line = None hit_line = None
for i, line in enumerate(lines_all, start=1): for i, line in enumerate(lines_all, start=1):
if pat.search(line): if pat.search(line):
@ -329,7 +323,8 @@ def register_resource_tools(mcp: FastMCP) -> None:
full_sha = hashlib.sha256(full_bytes).hexdigest() full_sha = hashlib.sha256(full_bytes).hexdigest()
# Selection only when explicitly requested via windowing args or request text hints # Selection only when explicitly requested via windowing args or request text hints
selection_requested = bool(head_bytes or tail_lines or (start_line is not None and line_count is not None) or request) selection_requested = bool(head_bytes or tail_lines or (
start_line is not None and line_count is not None) or request)
if selection_requested: if selection_requested:
# Mutually exclusive windowing options precedence: # Mutually exclusive windowing options precedence:
# 1) head_bytes, 2) tail_lines, 3) start_line+line_count, else full text # 1) head_bytes, 2) tail_lines, 3) start_line+line_count, else full text
@ -354,24 +349,19 @@ def register_resource_tools(mcp: FastMCP) -> None:
except Exception as e: except Exception as e:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@mcp.tool() @mcp.tool(name="find_in_file", description="Searches a file with a regex pattern and returns line numbers and excerpts.")
@telemetry_tool("find_in_file") @telemetry_tool("find_in_file")
async def find_in_file( async def find_in_file(
uri: str, ctx: Context,
pattern: str, uri: Annotated[str, "The resource URI to search under Assets/ or file path form supported by read_resource"],
ctx: Optional[Context] = None, pattern: Annotated[str, "The regex pattern to search for"],
ignore_case: Optional[bool] = True, ignore_case: Annotated[bool, "Case-insensitive search"] | None = True,
project_root: Optional[str] = None, project_root: Annotated[str,
max_results: Any = 200, "The project root directory"] | None = None,
) -> Dict[str, Any]: max_results: Annotated[int,
""" "Cap results to avoid huge payloads"] = 200,
Searches a file with a regex pattern and returns line numbers and excerpts. ) -> dict[str, Any]:
- uri: unity://path/Assets/... or file path form supported by read_resource ctx.info(f"Processing find_in_file: {uri}")
- pattern: regular expression (Python re)
- ignore_case: case-insensitive by default
- max_results: cap results to avoid huge payloads
"""
# re is already imported at module level
try: try:
project = _resolve_project_root(project_root) project = _resolve_project_root(project_root)
p = _resolve_safe_path_from_uri(uri, project) p = _resolve_safe_path_from_uri(uri, project)
@ -404,5 +394,3 @@ def register_resource_tools(mcp: FastMCP) -> None:
return {"success": True, "data": {"matches": results, "count": len(results)}} return {"success": True, "data": {"matches": results, "count": len(results)}}
except Exception as e: except Exception as e:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}

View File

@ -1,17 +1,18 @@
from config import config
import contextlib import contextlib
from dataclasses import dataclass
import errno import errno
import json import json
import logging import logging
from pathlib import Path
from port_discovery import PortDiscovery
import random import random
import socket import socket
import struct import struct
import threading import threading
import time import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict from typing import Any, Dict
from config import config
from port_discovery import PortDiscovery
# Configure logging using settings from config # Configure logging using settings from config
logging.basicConfig( logging.basicConfig(
@ -26,6 +27,7 @@ _connection_lock = threading.Lock()
# Maximum allowed framed payload size (64 MiB) # Maximum allowed framed payload size (64 MiB)
FRAMED_MAX = 64 * 1024 * 1024 FRAMED_MAX = 64 * 1024 * 1024
@dataclass @dataclass
class UnityConnection: class UnityConnection:
"""Manages the socket connection to the Unity Editor.""" """Manages the socket connection to the Unity Editor."""
@ -33,7 +35,7 @@ class UnityConnection:
port: int = None # Will be set dynamically port: int = None # Will be set dynamically
sock: socket.socket = None # Socket for Unity communication sock: socket.socket = None # Socket for Unity communication
use_framing: bool = False # Negotiated per-connection use_framing: bool = False # Negotiated per-connection
def __post_init__(self): def __post_init__(self):
"""Set port from discovery if not explicitly provided""" """Set port from discovery if not explicitly provided"""
if self.port is None: if self.port is None:
@ -50,11 +52,14 @@ class UnityConnection:
return True return True
try: try:
# Bounded connect to avoid indefinite blocking # Bounded connect to avoid indefinite blocking
connect_timeout = float(getattr(config, "connect_timeout", getattr(config, "connection_timeout", 1.0))) connect_timeout = float(
self.sock = socket.create_connection((self.host, self.port), connect_timeout) getattr(config, "connect_timeout", getattr(config, "connection_timeout", 1.0)))
self.sock = socket.create_connection(
(self.host, self.port), connect_timeout)
# Disable Nagle's algorithm to reduce small RPC latency # Disable Nagle's algorithm to reduce small RPC latency
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.sock.setsockopt(
socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
logger.debug(f"Connected to Unity at {self.host}:{self.port}") logger.debug(f"Connected to Unity at {self.host}:{self.port}")
# Strict handshake: require FRAMING=1 # Strict handshake: require FRAMING=1
@ -78,16 +83,20 @@ class UnityConnection:
if 'FRAMING=1' in text: if 'FRAMING=1' in text:
self.use_framing = True self.use_framing = True
logger.debug('Unity MCP handshake received: FRAMING=1 (strict)') logger.debug(
'Unity MCP handshake received: FRAMING=1 (strict)')
else: else:
if require_framing: if require_framing:
# Best-effort plain-text advisory for legacy peers # Best-effort plain-text advisory for legacy peers
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self.sock.sendall(b'Unity MCP requires FRAMING=1\n') self.sock.sendall(
raise ConnectionError(f'Unity MCP requires FRAMING=1, got: {text!r}') b'Unity MCP requires FRAMING=1\n')
raise ConnectionError(
f'Unity MCP requires FRAMING=1, got: {text!r}')
else: else:
self.use_framing = False self.use_framing = False
logger.warning('Unity MCP handshake missing FRAMING=1; proceeding in legacy mode by configuration') logger.warning(
'Unity MCP handshake missing FRAMING=1; proceeding in legacy mode by configuration')
finally: finally:
self.sock.settimeout(config.connection_timeout) self.sock.settimeout(config.connection_timeout)
return True return True
@ -116,7 +125,8 @@ class UnityConnection:
while len(data) < count: while len(data) < count:
chunk = sock.recv(count - len(data)) chunk = sock.recv(count - len(data))
if not chunk: if not chunk:
raise ConnectionError("Connection closed before reading expected bytes") raise ConnectionError(
"Connection closed before reading expected bytes")
data.extend(chunk) data.extend(chunk)
return bytes(data) return bytes(data)
@ -136,13 +146,16 @@ class UnityConnection:
heartbeat_count += 1 heartbeat_count += 1
if heartbeat_count >= getattr(config, 'max_heartbeat_frames', 16) or time.monotonic() > deadline: if heartbeat_count >= getattr(config, 'max_heartbeat_frames', 16) or time.monotonic() > deadline:
# Treat as empty successful response to match C# server behavior # Treat as empty successful response to match C# server behavior
logger.debug("Heartbeat threshold reached; returning empty response") logger.debug(
"Heartbeat threshold reached; returning empty response")
return b"" return b""
continue continue
if payload_len > FRAMED_MAX: if payload_len > FRAMED_MAX:
raise ValueError(f"Invalid framed length: {payload_len}") raise ValueError(
f"Invalid framed length: {payload_len}")
payload = self._read_exact(sock, payload_len) payload = self._read_exact(sock, payload_len)
logger.debug(f"Received framed response ({len(payload)} bytes)") logger.debug(
f"Received framed response ({len(payload)} bytes)")
return payload return payload
except socket.timeout as e: except socket.timeout as e:
logger.warning("Socket timeout during framed receive") logger.warning("Socket timeout during framed receive")
@ -158,21 +171,22 @@ class UnityConnection:
chunk = sock.recv(buffer_size) chunk = sock.recv(buffer_size)
if not chunk: if not chunk:
if not chunks: if not chunks:
raise Exception("Connection closed before receiving data") raise Exception(
"Connection closed before receiving data")
break break
chunks.append(chunk) chunks.append(chunk)
# Process the data received so far # Process the data received so far
data = b''.join(chunks) data = b''.join(chunks)
decoded_data = data.decode('utf-8') decoded_data = data.decode('utf-8')
# Check if we've received a complete response # Check if we've received a complete response
try: try:
# Special case for ping-pong # Special case for ping-pong
if decoded_data.strip().startswith('{"status":"success","result":{"message":"pong"'): if decoded_data.strip().startswith('{"status":"success","result":{"message":"pong"'):
logger.debug("Received ping response") logger.debug("Received ping response")
return data return data
# Handle escaped quotes in the content # Handle escaped quotes in the content
if '"content":' in decoded_data: if '"content":' in decoded_data:
# Find the content field and its value # Find the content field and its value
@ -182,19 +196,22 @@ class UnityConnection:
# Replace escaped quotes in content with regular quotes # Replace escaped quotes in content with regular quotes
content = decoded_data[content_start:content_end] content = decoded_data[content_start:content_end]
content = content.replace('\\"', '"') content = content.replace('\\"', '"')
decoded_data = decoded_data[:content_start] + content + decoded_data[content_end:] decoded_data = decoded_data[:content_start] + \
content + decoded_data[content_end:]
# Validate JSON format # Validate JSON format
json.loads(decoded_data) json.loads(decoded_data)
# If we get here, we have valid JSON # If we get here, we have valid JSON
logger.info(f"Received complete response ({len(data)} bytes)") logger.info(
f"Received complete response ({len(data)} bytes)")
return data return data
except json.JSONDecodeError: except json.JSONDecodeError:
# We haven't received a complete valid JSON response yet # We haven't received a complete valid JSON response yet
continue continue
except Exception as e: except Exception as e:
logger.warning(f"Error processing response chunk: {str(e)}") logger.warning(
f"Error processing response chunk: {str(e)}")
# Continue reading more chunks as this might not be the complete response # Continue reading more chunks as this might not be the complete response
continue continue
except socket.timeout: except socket.timeout:
@ -217,7 +234,8 @@ class UnityConnection:
def read_status_file() -> dict | None: def read_status_file() -> dict | None:
try: try:
status_files = sorted(Path.home().joinpath('.unity-mcp').glob('unity-mcp-status-*.json'), key=lambda p: p.stat().st_mtime, reverse=True) status_files = sorted(Path.home().joinpath(
'.unity-mcp').glob('unity-mcp-status-*.json'), key=lambda p: p.stat().st_mtime, reverse=True)
if not status_files: if not status_files:
return None return None
latest = status_files[0] latest = status_files[0]
@ -253,7 +271,8 @@ class UnityConnection:
payload = b'ping' payload = b'ping'
else: else:
command = {"type": command_type, "params": params or {}} command = {"type": command_type, "params": params or {}}
payload = json.dumps(command, ensure_ascii=False).encode('utf-8') payload = json.dumps(
command, ensure_ascii=False).encode('utf-8')
# Send/receive are serialized to protect the shared socket # Send/receive are serialized to protect the shared socket
with self._io_lock: with self._io_lock:
@ -280,7 +299,8 @@ class UnityConnection:
try: try:
response_data = self.receive_full_response(self.sock) response_data = self.receive_full_response(self.sock)
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
logger.debug("recv %d bytes; mode=%s", len(response_data), mode) logger.debug("recv %d bytes; mode=%s",
len(response_data), mode)
finally: finally:
if restore_timeout is not None: if restore_timeout is not None:
self.sock.settimeout(restore_timeout) self.sock.settimeout(restore_timeout)
@ -295,11 +315,13 @@ class UnityConnection:
resp = json.loads(response_data.decode('utf-8')) resp = json.loads(response_data.decode('utf-8'))
if resp.get('status') == 'error': if resp.get('status') == 'error':
err = resp.get('error') or resp.get('message', 'Unknown Unity error') err = resp.get('error') or resp.get(
'message', 'Unknown Unity error')
raise Exception(err) raise Exception(err)
return resp.get('result', {}) return resp.get('result', {})
except Exception as e: except Exception as e:
logger.warning(f"Unity communication attempt {attempt+1} failed: {e}") logger.warning(
f"Unity communication attempt {attempt+1} failed: {e}")
try: try:
if self.sock: if self.sock:
self.sock.close() self.sock.close()
@ -310,7 +332,8 @@ class UnityConnection:
try: try:
new_port = PortDiscovery.discover_unity_port() new_port = PortDiscovery.discover_unity_port()
if new_port != self.port: if new_port != self.port:
logger.info(f"Unity port changed {self.port} -> {new_port}") logger.info(
f"Unity port changed {self.port} -> {new_port}")
self.port = new_port self.port = new_port
except Exception as de: except Exception as de:
logger.debug(f"Port discovery failed: {de}") logger.debug(f"Port discovery failed: {de}")
@ -324,11 +347,13 @@ class UnityConnection:
jitter = random.uniform(0.1, 0.3) jitter = random.uniform(0.1, 0.3)
# Fastretry for transient socket failures # Fastretry for transient socket failures
fast_error = isinstance(e, (ConnectionRefusedError, ConnectionResetError, TimeoutError)) fast_error = isinstance(
e, (ConnectionRefusedError, ConnectionResetError, TimeoutError))
if not fast_error: if not fast_error:
try: try:
err_no = getattr(e, 'errno', None) err_no = getattr(e, 'errno', None)
fast_error = err_no in (errno.ECONNREFUSED, errno.ECONNRESET, errno.ETIMEDOUT) fast_error = err_no in (
errno.ECONNREFUSED, errno.ECONNRESET, errno.ETIMEDOUT)
except Exception: except Exception:
pass pass
@ -345,9 +370,11 @@ class UnityConnection:
continue continue
raise raise
# Global Unity connection # Global Unity connection
_unity_connection = None _unity_connection = None
def get_unity_connection() -> UnityConnection: def get_unity_connection() -> UnityConnection:
"""Retrieve or establish a persistent Unity connection. """Retrieve or establish a persistent Unity connection.
@ -366,7 +393,8 @@ def get_unity_connection() -> UnityConnection:
_unity_connection = UnityConnection() _unity_connection = UnityConnection()
if not _unity_connection.connect(): if not _unity_connection.connect():
_unity_connection = None _unity_connection = None
raise ConnectionError("Could not connect to Unity. Ensure the Unity Editor and MCP Bridge are running.") raise ConnectionError(
"Could not connect to Unity. Ensure the Unity Editor and MCP Bridge are running.")
logger.info("Connected to Unity on startup") logger.info("Connected to Unity on startup")
return _unity_connection return _unity_connection
@ -400,7 +428,8 @@ def send_command_with_retry(command_type: str, params: Dict[str, Any], *, max_re
response = conn.send_command(command_type, params) response = conn.send_command(command_type, params)
retries = 0 retries = 0
while _is_reloading_response(response) and retries < max_retries: while _is_reloading_response(response) and retries < max_retries:
delay_ms = int(response.get("retry_after_ms", retry_ms)) if isinstance(response, dict) else retry_ms delay_ms = int(response.get("retry_after_ms", retry_ms)
) if isinstance(response, dict) else retry_ms
time.sleep(max(0.0, delay_ms / 1000.0)) time.sleep(max(0.0, delay_ms / 1000.0))
retries += 1 retries += 1
response = conn.send_command(command_type, params) response = conn.send_command(command_type, params)
@ -415,7 +444,8 @@ async def async_send_command_with_retry(command_type: str, params: Dict[str, Any
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
return await loop.run_in_executor( return await loop.run_in_executor(
None, None,
lambda: send_command_with_retry(command_type, params, max_retries=max_retries, retry_ms=retry_ms), lambda: send_command_with_retry(
command_type, params, max_retries=max_retries, retry_ms=retry_ms),
) )
except Exception as e: except Exception as e:
# Return a structured error dict for consistency with other responses # Return a structured error dict for consistency with other responses