from config import config import contextlib from dataclasses import dataclass import errno import json import logging from pathlib import Path from port_discovery import PortDiscovery import random import socket import struct import threading import time from typing import Any, Dict # Configure logging using settings from config logging.basicConfig( level=getattr(logging, config.log_level), format=config.log_format ) logger = logging.getLogger("mcp-for-unity-server") # Module-level lock to guard global connection initialization _connection_lock = threading.Lock() # Maximum allowed framed payload size (64 MiB) FRAMED_MAX = 64 * 1024 * 1024 @dataclass class UnityConnection: """Manages the socket connection to the Unity Editor.""" host: str = config.unity_host port: int = None # Will be set dynamically sock: socket.socket = None # Socket for Unity communication use_framing: bool = False # Negotiated per-connection def __post_init__(self): """Set port from discovery if not explicitly provided""" if self.port is None: self.port = PortDiscovery.discover_unity_port() self._io_lock = threading.Lock() self._conn_lock = threading.Lock() def connect(self) -> bool: """Establish a connection to the Unity Editor.""" if self.sock: return True with self._conn_lock: if self.sock: return True try: # Bounded connect to avoid indefinite blocking connect_timeout = float( getattr(config, "connect_timeout", getattr(config, "connection_timeout", 1.0))) self.sock = socket.create_connection( (self.host, self.port), connect_timeout) # Disable Nagle's algorithm to reduce small RPC latency with contextlib.suppress(Exception): self.sock.setsockopt( socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) logger.debug(f"Connected to Unity at {self.host}:{self.port}") # Strict handshake: require FRAMING=1 try: require_framing = getattr(config, "require_framing", True) timeout = float(getattr(config, "handshake_timeout", 1.0)) self.sock.settimeout(timeout) buf = bytearray() deadline = time.monotonic() + timeout while time.monotonic() < deadline and len(buf) < 512: try: chunk = self.sock.recv(256) if not chunk: break buf.extend(chunk) if b"\n" in buf: break except socket.timeout: break text = bytes(buf).decode('ascii', errors='ignore').strip() if 'FRAMING=1' in text: self.use_framing = True logger.debug( 'Unity MCP handshake received: FRAMING=1 (strict)') else: if require_framing: # Best-effort plain-text advisory for legacy peers with contextlib.suppress(Exception): self.sock.sendall( b'Unity MCP requires FRAMING=1\n') raise ConnectionError( f'Unity MCP requires FRAMING=1, got: {text!r}') else: self.use_framing = False logger.warning( 'Unity MCP handshake missing FRAMING=1; proceeding in legacy mode by configuration') finally: self.sock.settimeout(config.connection_timeout) return True except Exception as e: logger.error(f"Failed to connect to Unity: {str(e)}") try: if self.sock: self.sock.close() except Exception: pass self.sock = None return False def disconnect(self): """Close the connection to the Unity Editor.""" if self.sock: try: self.sock.close() except Exception as e: logger.error(f"Error disconnecting from Unity: {str(e)}") finally: self.sock = None def _read_exact(self, sock: socket.socket, count: int) -> bytes: data = bytearray() while len(data) < count: chunk = sock.recv(count - len(data)) if not chunk: raise ConnectionError( "Connection closed before reading expected bytes") data.extend(chunk) return bytes(data) def receive_full_response(self, sock, buffer_size=config.buffer_size) -> bytes: """Receive a complete response from Unity, handling chunked data.""" if self.use_framing: try: # Consume heartbeats, but do not hang indefinitely if only zero-length frames arrive heartbeat_count = 0 deadline = time.monotonic() + getattr(config, 'framed_receive_timeout', 2.0) while True: header = self._read_exact(sock, 8) payload_len = struct.unpack('>Q', header)[0] if payload_len == 0: # Heartbeat/no-op frame: consume and continue waiting for a data frame logger.debug("Received heartbeat frame (length=0)") heartbeat_count += 1 if heartbeat_count >= getattr(config, 'max_heartbeat_frames', 16) or time.monotonic() > deadline: # Treat as empty successful response to match C# server behavior logger.debug( "Heartbeat threshold reached; returning empty response") return b"" continue if payload_len > FRAMED_MAX: raise ValueError( f"Invalid framed length: {payload_len}") payload = self._read_exact(sock, payload_len) logger.debug( f"Received framed response ({len(payload)} bytes)") return payload except socket.timeout as e: logger.warning("Socket timeout during framed receive") raise TimeoutError("Timeout receiving Unity response") from e except Exception as e: logger.error(f"Error during framed receive: {str(e)}") raise chunks = [] # Respect the socket's currently configured timeout try: while True: chunk = sock.recv(buffer_size) if not chunk: if not chunks: raise Exception( "Connection closed before receiving data") break chunks.append(chunk) # Process the data received so far data = b''.join(chunks) decoded_data = data.decode('utf-8') # Check if we've received a complete response try: # Special case for ping-pong if decoded_data.strip().startswith('{"status":"success","result":{"message":"pong"'): logger.debug("Received ping response") return data # Handle escaped quotes in the content if '"content":' in decoded_data: # Find the content field and its value content_start = decoded_data.find('"content":') + 9 content_end = decoded_data.rfind('"', content_start) if content_end > content_start: # Replace escaped quotes in content with regular quotes content = decoded_data[content_start:content_end] content = content.replace('\\"', '"') decoded_data = decoded_data[:content_start] + \ content + decoded_data[content_end:] # Validate JSON format json.loads(decoded_data) # If we get here, we have valid JSON logger.info( f"Received complete response ({len(data)} bytes)") return data except json.JSONDecodeError: # We haven't received a complete valid JSON response yet continue except Exception as e: logger.warning( f"Error processing response chunk: {str(e)}") # Continue reading more chunks as this might not be the complete response continue except socket.timeout: logger.warning("Socket timeout during receive") raise Exception("Timeout receiving Unity response") except Exception as e: logger.error(f"Error during receive: {str(e)}") raise def send_command(self, command_type: str, params: Dict[str, Any] = None) -> Dict[str, Any]: """Send a command with retry/backoff and port rediscovery. Pings only when requested.""" # Defensive guard: catch empty/placeholder invocations early if not command_type: raise ValueError("MCP call missing command_type") if params is None: # Return a fast, structured error that clients can display without hanging return {"success": False, "error": "MCP call received with no parameters (client placeholder?)"} attempts = max(config.max_retries, 5) base_backoff = max(0.5, config.retry_delay) def read_status_file() -> dict | None: try: status_files = sorted(Path.home().joinpath( '.unity-mcp').glob('unity-mcp-status-*.json'), key=lambda p: p.stat().st_mtime, reverse=True) if not status_files: return None latest = status_files[0] with latest.open('r') as f: return json.load(f) except Exception: return None last_short_timeout = None # Preflight: if Unity reports reloading, return a structured hint so clients can retry politely try: status = read_status_file() if status and (status.get('reloading') or status.get('reason') == 'reloading'): return { "success": False, "state": "reloading", "retry_after_ms": int(config.reload_retry_ms), "error": "Unity domain reload in progress", "message": "Unity is reloading scripts; please retry shortly" } except Exception: pass for attempt in range(attempts + 1): try: # Ensure connected (handshake occurs within connect()) if not self.sock and not self.connect(): raise Exception("Could not connect to Unity") # Build payload if command_type == 'ping': payload = b'ping' else: command = {"type": command_type, "params": params or {}} payload = json.dumps( command, ensure_ascii=False).encode('utf-8') # Send/receive are serialized to protect the shared socket with self._io_lock: mode = 'framed' if self.use_framing else 'legacy' with contextlib.suppress(Exception): logger.debug( "send %d bytes; mode=%s; head=%s", len(payload), mode, (payload[:32]).decode('utf-8', 'ignore'), ) if self.use_framing: header = struct.pack('>Q', len(payload)) self.sock.sendall(header) self.sock.sendall(payload) else: self.sock.sendall(payload) # During retry bursts use a short receive timeout and ensure restoration restore_timeout = None if attempt > 0 and last_short_timeout is None: restore_timeout = self.sock.gettimeout() self.sock.settimeout(1.0) try: response_data = self.receive_full_response(self.sock) with contextlib.suppress(Exception): logger.debug("recv %d bytes; mode=%s", len(response_data), mode) finally: if restore_timeout is not None: self.sock.settimeout(restore_timeout) last_short_timeout = None # Parse if command_type == 'ping': resp = json.loads(response_data.decode('utf-8')) if resp.get('status') == 'success' and resp.get('result', {}).get('message') == 'pong': return {"message": "pong"} raise Exception("Ping unsuccessful") resp = json.loads(response_data.decode('utf-8')) if resp.get('status') == 'error': err = resp.get('error') or resp.get( 'message', 'Unknown Unity error') raise Exception(err) return resp.get('result', {}) except Exception as e: logger.warning( f"Unity communication attempt {attempt+1} failed: {e}") try: if self.sock: self.sock.close() finally: self.sock = None # Re-discover port each time try: new_port = PortDiscovery.discover_unity_port() if new_port != self.port: logger.info( f"Unity port changed {self.port} -> {new_port}") self.port = new_port except Exception as de: logger.debug(f"Port discovery failed: {de}") if attempt < attempts: # Heartbeat-aware, jittered backoff status = read_status_file() # Base exponential backoff backoff = base_backoff * (2 ** attempt) # Decorrelated jitter multiplier jitter = random.uniform(0.1, 0.3) # Fast‑retry for transient socket failures fast_error = isinstance( e, (ConnectionRefusedError, ConnectionResetError, TimeoutError)) if not fast_error: try: err_no = getattr(e, 'errno', None) fast_error = err_no in ( errno.ECONNREFUSED, errno.ECONNRESET, errno.ETIMEDOUT) except Exception: pass # Cap backoff depending on state if status and status.get('reloading'): cap = 0.8 elif fast_error: cap = 0.25 else: cap = 3.0 sleep_s = min(cap, jitter * (2 ** attempt)) time.sleep(sleep_s) continue raise # Global Unity connection _unity_connection = None def get_unity_connection() -> UnityConnection: """Retrieve or establish a persistent Unity connection. Note: Do NOT ping on every retrieval to avoid connection storms. Rely on send_command() exceptions to detect broken sockets and reconnect there. """ global _unity_connection if _unity_connection is not None: return _unity_connection # Double-checked locking to avoid concurrent socket creation with _connection_lock: if _unity_connection is not None: return _unity_connection logger.info("Creating new Unity connection") _unity_connection = UnityConnection() if not _unity_connection.connect(): _unity_connection = None raise ConnectionError( "Could not connect to Unity. Ensure the Unity Editor and MCP Bridge are running.") logger.info("Connected to Unity on startup") return _unity_connection # ----------------------------- # Centralized retry helpers # ----------------------------- def _is_reloading_response(resp: dict) -> bool: """Return True if the Unity response indicates the editor is reloading.""" if not isinstance(resp, dict): return False if resp.get("state") == "reloading": return True message_text = (resp.get("message") or resp.get("error") or "").lower() return "reload" in message_text def send_command_with_retry(command_type: str, params: Dict[str, Any], *, max_retries: int | None = None, retry_ms: int | None = None) -> Dict[str, Any]: """Send a command via the shared connection, waiting politely through Unity reloads. Uses config.reload_retry_ms and config.reload_max_retries by default. Preserves the structured failure if retries are exhausted. """ conn = get_unity_connection() if max_retries is None: max_retries = getattr(config, "reload_max_retries", 40) if retry_ms is None: retry_ms = getattr(config, "reload_retry_ms", 250) response = conn.send_command(command_type, params) retries = 0 while _is_reloading_response(response) and retries < max_retries: delay_ms = int(response.get("retry_after_ms", retry_ms) ) if isinstance(response, dict) else retry_ms time.sleep(max(0.0, delay_ms / 1000.0)) retries += 1 response = conn.send_command(command_type, params) return response async def async_send_command_with_retry(command_type: str, params: Dict[str, Any], *, loop=None, max_retries: int | None = None, retry_ms: int | None = None) -> Dict[str, Any]: """Async wrapper that runs the blocking retry helper in a thread pool.""" try: import asyncio # local import to avoid mandatory asyncio dependency for sync callers if loop is None: loop = asyncio.get_running_loop() return await loop.run_in_executor( None, lambda: send_command_with_retry( command_type, params, max_retries=max_retries, retry_ms=retry_ms), ) except Exception as e: # Return a structured error dict for consistency with other responses return {"success": False, "error": f"Python async retry helper failed: {str(e)}"}