# unity-mcp/MCPForUnity/UnityMcpServer~/src/unity_connection.py
#
# Repository viewer metadata: 452 lines, 19 KiB, Python (viewer actions: Raw, Blame, History).
#
# Viewer warning: this file contains ambiguous Unicode characters that may be
# confused with others in your current locale. If your use case is intentional
# and legitimate, you can safely ignore this warning; otherwise use the viewer's
# Escape button to highlight these characters.

from config import config
import contextlib
from dataclasses import dataclass
import errno
import json
import logging
from pathlib import Path
from port_discovery import PortDiscovery
import random
import socket
import struct
import threading
import time
from typing import Any, Dict
from models import MCPResponse
# Root logging is set up once, at import time, from the shared config object:
# the level name is resolved to the matching constant on the logging module.
logging.basicConfig(
    level=getattr(logging, config.log_level),
    format=config.log_format,
)
logger = logging.getLogger("mcp-for-unity-server")

# Guards creation of the module-wide connection object
_connection_lock = threading.Lock()

# Largest framed payload accepted from the peer: 64 MiB
FRAMED_MAX = 64 * 1024 * 1024
@dataclass
class UnityConnection:
    """Manages the socket connection to the Unity Editor.

    The socket is opened by :meth:`connect`, which also reads the peer's
    greeting to decide between framed transport (8-byte big-endian length
    prefix per message) and the legacy unframed JSON stream.
    :meth:`send_command` layers retry, port rediscovery and backoff on top.
    """
    host: str = config.unity_host
    port: int | None = None  # Filled in from PortDiscovery when not provided
    sock: socket.socket | None = None  # Socket for Unity communication; None while disconnected
    use_framing: bool = False  # Negotiated per-connection

    def __post_init__(self):
        """Set port from discovery if not explicitly provided and create the locks."""
        if self.port is None:
            self.port = PortDiscovery.discover_unity_port()
        # _io_lock wraps each send/receive exchange; _conn_lock wraps connect()
        self._io_lock = threading.Lock()
        self._conn_lock = threading.Lock()

    def connect(self) -> bool:
        """Establish a connection to the Unity Editor.

        Returns:
            True if a socket is already open, or a new one was opened and the
            handshake completed. False on any failure; the error is logged and
            the socket is closed and reset to None.
        """
        if self.sock:
            return True
        with self._conn_lock:
            # Re-check under the lock: another caller may have connected meanwhile
            if self.sock:
                return True
            try:
                # Bounded connect to avoid indefinite blocking
                connect_timeout = float(
                    getattr(config, "connect_timeout", getattr(config, "connection_timeout", 1.0)))
                self.sock = socket.create_connection(
                    (self.host, self.port), connect_timeout)
                # Disable Nagle's algorithm to reduce small RPC latency
                with contextlib.suppress(Exception):
                    self.sock.setsockopt(
                        socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                logger.debug(f"Connected to Unity at {self.host}:{self.port}")
                # Strict handshake: require FRAMING=1
                try:
                    self._negotiate_framing()
                finally:
                    # Switch from the handshake timeout to the regular I/O timeout
                    self.sock.settimeout(config.connection_timeout)
                return True
            except Exception as e:
                logger.error(f"Failed to connect to Unity: {str(e)}")
                try:
                    if self.sock:
                        self.sock.close()
                except Exception:
                    pass
                self.sock = None
                return False

    def _negotiate_framing(self) -> None:
        """Read the peer greeting and set ``use_framing`` accordingly.

        Reads until a newline, 512 bytes, peer close, or the handshake timeout.

        Raises:
            ConnectionError: the greeting lacks ``FRAMING=1`` and
                ``config.require_framing`` (default True) is set.
        """
        require_framing = getattr(config, "require_framing", True)
        timeout = float(getattr(config, "handshake_timeout", 1.0))
        self.sock.settimeout(timeout)
        buf = bytearray()
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline and len(buf) < 512:
            try:
                chunk = self.sock.recv(256)
            except socket.timeout:
                break
            if not chunk:
                break
            buf.extend(chunk)
            if b"\n" in buf:
                break
        text = bytes(buf).decode('ascii', errors='ignore').strip()
        if 'FRAMING=1' in text:
            self.use_framing = True
            logger.debug(
                'MCP for Unity handshake received: FRAMING=1 (strict)')
            return
        if require_framing:
            # Best-effort plain-text advisory for legacy peers
            with contextlib.suppress(Exception):
                self.sock.sendall(
                    b'MCP for Unity requires FRAMING=1\n')
            raise ConnectionError(
                f'MCP for Unity requires FRAMING=1, got: {text!r}')
        self.use_framing = False
        logger.warning(
            'MCP for Unity handshake missing FRAMING=1; proceeding in legacy mode by configuration')

    def disconnect(self):
        """Close the connection to the Unity Editor.

        Close errors are logged, never raised; ``sock`` is always reset to None.
        """
        if self.sock:
            try:
                self.sock.close()
            except Exception as e:
                logger.error(f"Error disconnecting from Unity: {str(e)}")
            finally:
                self.sock = None

    def _read_exact(self, sock: socket.socket, count: int) -> bytes:
        """Read exactly ``count`` bytes from ``sock``.

        Raises:
            ConnectionError: the peer closed the connection before ``count``
                bytes arrived.
        """
        data = bytearray()
        while len(data) < count:
            chunk = sock.recv(count - len(data))
            if not chunk:
                raise ConnectionError(
                    "Connection closed before reading expected bytes")
            data.extend(chunk)
        return bytes(data)

    def receive_full_response(self, sock, buffer_size=config.buffer_size) -> bytes:
        """Receive a complete response from Unity, handling chunked data.

        Args:
            sock: Connected socket to read from; its configured timeout applies.
            buffer_size: Per-``recv`` size used on the legacy (unframed) path.

        Returns:
            The raw response bytes. On the framed path this may be ``b""`` when
            only heartbeat frames arrived within the heartbeat limits.

        Raises:
            TimeoutError: the socket timed out before a full response arrived.
            ConnectionError: the peer closed the connection mid-response.
            ValueError: a framed length exceeds ``FRAMED_MAX``.
        """
        if self.use_framing:
            return self._receive_framed(sock)
        return self._receive_legacy(sock, buffer_size)

    def _receive_framed(self, sock) -> bytes:
        """Read one length-prefixed frame, skipping zero-length heartbeat frames."""
        try:
            # Consume heartbeats, but do not hang indefinitely if only zero-length frames arrive
            heartbeat_count = 0
            deadline = time.monotonic() + getattr(config, 'framed_receive_timeout', 2.0)
            while True:
                header = self._read_exact(sock, 8)
                payload_len = struct.unpack('>Q', header)[0]
                if payload_len == 0:
                    # Heartbeat/no-op frame: consume and continue waiting for a data frame
                    logger.debug("Received heartbeat frame (length=0)")
                    heartbeat_count += 1
                    if heartbeat_count >= getattr(config, 'max_heartbeat_frames', 16) or time.monotonic() > deadline:
                        # Treat as empty successful response to match C# server behavior
                        logger.debug(
                            "Heartbeat threshold reached; returning empty response")
                        return b""
                    continue
                if payload_len > FRAMED_MAX:
                    raise ValueError(
                        f"Invalid framed length: {payload_len}")
                payload = self._read_exact(sock, payload_len)
                logger.debug(
                    f"Received framed response ({len(payload)} bytes)")
                return payload
        except socket.timeout as e:
            logger.warning("Socket timeout during framed receive")
            raise TimeoutError("Timeout receiving Unity response") from e
        except Exception as e:
            logger.error(f"Error during framed receive: {str(e)}")
            raise

    @staticmethod
    def _unescape_content_quotes(text: str) -> str:
        """Return ``text`` with ``\\"`` replaced by ``"`` inside the span after ``"content"``.

        Legacy heuristic kept as a fallback only: the span runs from just after
        the ``"content"`` key to the last double quote in ``text``.
        """
        if '"content":' not in text:
            return text
        content_start = text.find('"content":') + 9
        content_end = text.rfind('"', content_start)
        if content_end <= content_start:
            return text
        content = text[content_start:content_end].replace('\\"', '"')
        return text[:content_start] + content + text[content_end:]

    def _receive_legacy(self, sock, buffer_size) -> bytes:
        """Accumulate unframed chunks until they form a complete JSON document."""
        chunks = []
        # Respect the socket's currently configured timeout
        try:
            while True:
                chunk = sock.recv(buffer_size)
                if not chunk:
                    if not chunks:
                        raise Exception(
                            "Connection closed before receiving data")
                    # Data arrived but never parsed as a complete response;
                    # fail explicitly instead of implicitly returning None.
                    raise ConnectionError(
                        "Connection closed before receiving a complete response")
                chunks.append(chunk)
                # Process the data received so far
                data = b''.join(chunks)
                try:
                    decoded_data = data.decode('utf-8')
                except UnicodeDecodeError as decode_err:
                    # An error that reaches the end of the buffer is most likely a
                    # multi-byte character split across recv() calls: keep reading.
                    if decode_err.end < len(data):
                        raise
                    continue
                # Check if we've received a complete response
                try:
                    # Special case for ping-pong
                    if decoded_data.strip().startswith('{"status":"success","result":{"message":"pong"'):
                        logger.debug("Received ping response")
                        return data
                    try:
                        # Validate the text exactly as received first, so responses
                        # with properly escaped quotes are accepted.
                        json.loads(decoded_data)
                    except json.JSONDecodeError:
                        # Fall back to the legacy escaped-quote handling
                        relaxed = self._unescape_content_quotes(decoded_data)
                        if relaxed == decoded_data:
                            raise
                        json.loads(relaxed)
                    # If we get here, we have valid JSON
                    logger.info(
                        f"Received complete response ({len(data)} bytes)")
                    return data
                except json.JSONDecodeError:
                    # We haven't received a complete valid JSON response yet
                    continue
                except Exception as e:
                    logger.warning(
                        f"Error processing response chunk: {str(e)}")
                    # Continue reading more chunks as this might not be the complete response
                    continue
        except socket.timeout as e:
            logger.warning("Socket timeout during receive")
            # Same exception type as the framed path, so send_command classifies
            # both kinds of timeout identically.
            raise TimeoutError("Timeout receiving Unity response") from e
        except Exception as e:
            logger.error(f"Error during receive: {str(e)}")
            raise

    def send_command(self, command_type: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
        """Send a command with retry/backoff and port rediscovery. Pings only when requested.

        Args:
            command_type: Command name; ``'ping'`` sends the literal ``b'ping'``
                payload instead of a JSON envelope.
            params: Command parameters. ``None`` is treated as a placeholder
                call and answered with a failure ``MCPResponse``.

        Returns:
            The ``result`` member of Unity's reply (``{"message": "pong"}`` for a
            successful ping), or an ``MCPResponse`` failure object when
            ``params`` is None or the status file reports a reload in progress.

        Raises:
            ValueError: ``command_type`` is empty.
            Exception: the last communication error once all attempts fail.
        """
        # Defensive guard: catch empty/placeholder invocations early
        if not command_type:
            raise ValueError("MCP call missing command_type")
        if params is None:
            return MCPResponse(success=False, error="MCP call received with no parameters (client placeholder?)")
        attempts = max(config.max_retries, 5)

        def read_status_file() -> dict | None:
            """Load the most recently modified status file, or None if unavailable."""
            try:
                status_files = sorted(
                    Path.home().joinpath('.unity-mcp').glob('unity-mcp-status-*.json'),
                    key=lambda p: p.stat().st_mtime,
                    reverse=True,
                )
                if not status_files:
                    return None
                with status_files[0].open('r') as f:
                    loaded = json.load(f)
            except Exception:
                return None
            # Callers use .get(); ignore any JSON document that is not an object
            return loaded if isinstance(loaded, dict) else None

        # Preflight: if Unity reports reloading, return a structured hint so clients can retry politely
        try:
            status = read_status_file()
            if status and (status.get('reloading') or status.get('reason') == 'reloading'):
                return MCPResponse(
                    success=False,
                    error="Unity domain reload in progress, please try again shortly",
                    data={"state": "reloading", "retry_after_ms": int(
                        config.reload_retry_ms)}
                )
        except Exception as preflight_err:
            # Best-effort check only; fall through to the normal send path
            logger.debug(f"Reload preflight check failed: {preflight_err}")

        for attempt in range(attempts + 1):
            try:
                # Ensure connected (handshake occurs within connect())
                if not self.sock and not self.connect():
                    raise Exception("Could not connect to Unity")
                # Build payload
                if command_type == 'ping':
                    payload = b'ping'
                else:
                    command = {"type": command_type, "params": params or {}}
                    payload = json.dumps(
                        command, ensure_ascii=False).encode('utf-8')
                # Send/receive are serialized to protect the shared socket
                with self._io_lock:
                    mode = 'framed' if self.use_framing else 'legacy'
                    with contextlib.suppress(Exception):
                        logger.debug(
                            "send %d bytes; mode=%s; head=%s",
                            len(payload),
                            mode,
                            (payload[:32]).decode('utf-8', 'ignore'),
                        )
                    if self.use_framing:
                        header = struct.pack('>Q', len(payload))
                        self.sock.sendall(header)
                        self.sock.sendall(payload)
                    else:
                        self.sock.sendall(payload)
                    # During retry bursts use a short receive timeout and ensure restoration
                    shorten_timeout = attempt > 0
                    previous_timeout = self.sock.gettimeout() if shorten_timeout else None
                    if shorten_timeout:
                        self.sock.settimeout(1.0)
                    try:
                        response_data = self.receive_full_response(self.sock)
                        with contextlib.suppress(Exception):
                            logger.debug("recv %d bytes; mode=%s",
                                         len(response_data), mode)
                    finally:
                        # Restore even when the previous timeout was None (blocking)
                        if shorten_timeout:
                            self.sock.settimeout(previous_timeout)
                # Parse
                resp = json.loads(response_data.decode('utf-8'))
                if command_type == 'ping':
                    if resp.get('status') == 'success' and resp.get('result', {}).get('message') == 'pong':
                        return {"message": "pong"}
                    raise Exception("Ping unsuccessful")
                if resp.get('status') == 'error':
                    err = resp.get('error') or resp.get(
                        'message', 'Unknown Unity error')
                    raise Exception(err)
                return resp.get('result', {})
            except Exception as e:
                logger.warning(
                    f"Unity communication attempt {attempt+1} failed: {e}")
                # Drop the socket so the next attempt reconnects; a failing
                # close() must not abort the retry loop.
                if self.sock:
                    with contextlib.suppress(Exception):
                        self.sock.close()
                self.sock = None
                # Re-discover port each time
                try:
                    new_port = PortDiscovery.discover_unity_port()
                    if new_port != self.port:
                        logger.info(
                            f"Unity port changed {self.port} -> {new_port}")
                    self.port = new_port
                except Exception as de:
                    logger.debug(f"Port discovery failed: {de}")
                if attempt < attempts:
                    # Heartbeat-aware, jittered backoff
                    status = read_status_file()
                    # Jitter multiplier, doubled per attempt and capped below
                    jitter = random.uniform(0.1, 0.3)
                    # Fast retry for transient socket failures
                    fast_error = isinstance(
                        e, (ConnectionRefusedError, ConnectionResetError, TimeoutError))
                    if not fast_error:
                        fast_error = getattr(e, 'errno', None) in (
                            errno.ECONNREFUSED, errno.ECONNRESET, errno.ETIMEDOUT)
                    # Cap backoff depending on state
                    if status and status.get('reloading'):
                        cap = 0.8
                    elif fast_error:
                        cap = 0.25
                    else:
                        cap = 3.0
                    sleep_s = min(cap, jitter * (2 ** attempt))
                    time.sleep(sleep_s)
                    continue
                raise
# Module-wide connection object, created on first use
_unity_connection = None


def get_unity_connection() -> UnityConnection:
    """Return the shared Unity connection, creating and connecting it on first use.

    No ping is issued here (that would cause connection storms); a broken
    socket is detected by send_command(), which reconnects on failure.

    Raises:
        ConnectionError: a new connection object could not connect to Unity.
    """
    global _unity_connection
    # Fast path: reuse the existing object without taking the lock
    if _unity_connection is not None:
        return _unity_connection
    with _connection_lock:
        # Check again under the lock to avoid concurrent socket creation
        if _unity_connection is None:
            logger.info("Creating new Unity connection")
            _unity_connection = UnityConnection()
            connected = _unity_connection.connect()
            if not connected:
                _unity_connection = None
                raise ConnectionError(
                    "Could not connect to Unity. Ensure the Unity Editor and MCP Bridge are running.")
            logger.info("Connected to Unity on startup")
        return _unity_connection
# -----------------------------
# Centralized retry helpers
# -----------------------------
def _is_reloading_response(resp: dict) -> bool:
    """Return True if the Unity response indicates the editor is reloading.

    Accepts either a plain dict or a response object exposing the same fields
    as attributes (``state``, ``data``, ``message``, ``error``), so the
    structured failure object that ``send_command`` returns while a reload is
    in progress is recognised as well.

    A response counts as "reloading" when:
      * its ``state`` is ``"reloading"``, or
      * its ``data`` is a dict whose ``state`` is ``"reloading"``, or
      * its ``message`` (or, failing that, ``error``) is a string containing
        ``"reload"``, case-insensitively.
    """
    if resp is None:
        return False
    if isinstance(resp, dict):
        read = resp.get
    else:
        def read(key):
            return getattr(resp, key, None)
    if read("state") == "reloading":
        return True
    nested = read("data")
    if isinstance(nested, dict) and nested.get("state") == "reloading":
        return True
    message_text = read("message") or read("error") or ""
    # Guard against non-string payloads instead of failing on .lower()
    return isinstance(message_text, str) and "reload" in message_text.lower()
def send_command_with_retry(command_type: str, params: Dict[str, Any], *, max_retries: int | None = None, retry_ms: int | None = None) -> Dict[str, Any]:
    """Send a command on the shared connection, re-sending while Unity reports a reload.

    Args:
        command_type: Command name forwarded to ``send_command``.
        params: Command parameters forwarded to ``send_command``.
        max_retries: Maximum number of re-sends; defaults to
            ``config.reload_max_retries`` (40 if unset).
        retry_ms: Delay between re-sends in milliseconds when the response
            carries no ``retry_after_ms``; defaults to
            ``config.reload_retry_ms`` (250 if unset).

    Returns:
        The last response received, including the reloading response itself
        when the retry budget runs out.
    """
    conn = get_unity_connection()
    if max_retries is None:
        max_retries = getattr(config, "reload_max_retries", 40)
    if retry_ms is None:
        retry_ms = getattr(config, "reload_retry_ms", 250)

    response = conn.send_command(command_type, params)
    attempts_used = 0
    while _is_reloading_response(response) and attempts_used < max_retries:
        # Prefer the delay suggested by the response when it is a dict
        if isinstance(response, dict):
            delay_ms = int(response.get("retry_after_ms", retry_ms))
        else:
            delay_ms = retry_ms
        pause_s = delay_ms / 1000.0
        time.sleep(max(0.0, pause_s))
        attempts_used += 1
        response = conn.send_command(command_type, params)
    return response
async def async_send_command_with_retry(command_type: str, params: dict[str, Any], *, loop=None, max_retries: int | None = None, retry_ms: int | None = None) -> dict[str, Any] | MCPResponse:
    """Run the blocking ``send_command_with_retry`` in the loop's default executor.

    Args:
        command_type: Command name forwarded to the blocking helper.
        params: Command parameters forwarded to the blocking helper.
        loop: Event loop to schedule on; the running loop is used when None.
        max_retries: Forwarded unchanged to the blocking helper.
        retry_ms: Forwarded unchanged to the blocking helper.

    Returns:
        The helper's result, or a failure ``MCPResponse`` carrying the error
        text if anything raises.
    """
    try:
        import asyncio  # local import to avoid mandatory asyncio dependency for sync callers

        event_loop = asyncio.get_running_loop() if loop is None else loop

        def _blocking_call():
            return send_command_with_retry(
                command_type, params, max_retries=max_retries, retry_ms=retry_ms)

        return await event_loop.run_in_executor(None, _blocking_call)
    except Exception as e:
        return MCPResponse(success=False, error=str(e))