Server: Robust shutdown on stdio detach (signals, stdin/parent monitor, forced exit) (#363)
* Server: robust shutdown on stdio detach (signals, stdin/parent monitor, forced exit)
* Tests: move telemetry tests to tests/ and convert to asserts
* Server: simplify _force_exit to os._exit; guard exit timers to avoid duplicates; fix Windows ValueError in parent monitor; tests: add autouse cwd fixture for telemetry to locate pyproject.toml
* Server: add DEBUG logs for transient stdin checks and monitor thread errors
* Mirror shutdown improvements: signal handlers, stdin/parent monitor, guarded exit timers, and os._exit force-exit in UnityMcpServer~ entry points
main
parent
040eb6d701
commit
ca01fc7610
|
|
@ -4,6 +4,9 @@ import logging
|
|||
from logging.handlers import RotatingFileHandler
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import AsyncIterator, Dict, Any
|
||||
from config import config
|
||||
from tools import register_all_tools
|
||||
|
|
@ -64,6 +67,10 @@ except Exception:
|
|||
# Global connection state
|
||||
_unity_connection: UnityConnection = None
|
||||
|
||||
# Global shutdown coordination
|
||||
_shutdown_flag = threading.Event()
|
||||
_exit_timer_scheduled = threading.Event()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def server_lifespan(server: FastMCP) -> AsyncIterator[Dict[str, Any]]:
|
||||
|
|
@ -186,9 +193,98 @@ register_all_tools(mcp)
|
|||
register_all_resources(mcp)
|
||||
|
||||
|
||||
def _force_exit(code: int = 0):
    """Terminate the process immediately with exit status ``code``.

    Delegates to ``os._exit`` so that background threads which are still
    running cannot delay or block the exit.
    """
    os._exit(code)
|
||||
|
||||
|
||||
def _signal_handler(signum, frame):
    """Signal callback: request shutdown and arm the forced-exit timer once.

    Args:
        signum: Number of the signal that was received (only logged).
        frame: Current stack frame supplied by the signal machinery (unused).
    """
    logger.info(f"Received signal {signum}, initiating shutdown...")
    _shutdown_flag.set()

    # Another shutdown path already armed a timer; do not start a second one.
    if _exit_timer_scheduled.is_set():
        return

    _exit_timer_scheduled.set()
    watchdog = threading.Timer(1.0, _force_exit, args=(0,))
    watchdog.start()
|
||||
|
||||
|
||||
def _monitor_stdin():
    """Background thread to detect stdio detach (stdin closed) or parent exit.

    Wakes every 0.5 s until ``_shutdown_flag`` is set. When the parent
    process has disappeared or ``sys.stdin`` is closed or unusable, it sets
    ``_shutdown_flag`` and, unless an exit timer is already scheduled, arms a
    0.5 s timer that calls ``_force_exit(0)``. Unexpected errors are logged at
    DEBUG level and never propagate out of the thread.
    """
    try:
        # os.kill(pid, 0) is an existence probe only on POSIX. Per the os.kill
        # documentation, on Windows the call either sends a console control
        # event or terminates the target, so the parent is not probed there.
        can_probe = os.name == "posix" and hasattr(os, "getppid")
        parent_pid = os.getppid() if can_probe else None

        # wait() doubles as the poll interval and returns True as soon as
        # shutdown has been requested elsewhere.
        while not _shutdown_flag.wait(0.5):
            if parent_pid is not None:
                try:
                    os.kill(parent_pid, 0)
                except PermissionError:
                    # The process exists but we may not signal it: the
                    # parent is still alive, so keep monitoring.
                    pass
                except ProcessLookupError:
                    logger.info(f"Parent process {parent_pid} no longer exists; shutting down")
                    break
                except (ValueError, OSError):
                    # The probe itself is unusable here; an inconclusive
                    # result must not shut the server down.
                    logger.debug("Parent probe unavailable; disabling it", exc_info=True)
                    parent_pid = None

            try:
                if sys.stdin.closed:
                    logger.info("stdin.closed is True; client disconnected")
                    break
                if sys.stdin.fileno() < 0:
                    logger.info("stdin fd invalid; client disconnected")
                    break
            except (ValueError, OSError, AttributeError):
                # Closed pipe or unavailable stdin
                break
            except Exception:
                # Ignore transient errors
                logger.debug("Transient error checking stdin", exc_info=True)

        if not _shutdown_flag.is_set():
            logger.info("Client disconnected (stdin or parent), initiating shutdown...")
            _shutdown_flag.set()
        # Guarded so that only one forced-exit timer is ever started.
        if not _exit_timer_scheduled.is_set():
            _exit_timer_scheduled.set()
            threading.Timer(0.5, _force_exit, args=(0,)).start()
    except Exception:
        # Never let monitor thread crash the process
        logger.debug("Monitor thread error", exc_info=True)
|
||||
|
||||
|
||||
def main():
    """Entry point for uvx and console scripts.

    Installs the signal handlers, starts the stdin/parent monitor thread and
    then runs the FastMCP server over stdio exactly once. However the run
    ends, the shutdown flag is set and a forced-exit timer is armed (unless
    one is already scheduled).
    """
    # Handlers and monitor must be in place BEFORE the blocking mcp.run() call;
    # running the server ahead of this setup would bypass the shutdown logic.
    try:
        signal.signal(signal.SIGTERM, _signal_handler)
        signal.signal(signal.SIGINT, _signal_handler)
        if hasattr(signal, "SIGPIPE"):
            signal.signal(signal.SIGPIPE, signal.SIG_IGN)
        if hasattr(signal, "SIGBREAK"):
            signal.signal(signal.SIGBREAK, _signal_handler)
    except Exception:
        # Signals can fail in some environments; continue without them but
        # leave a trace for debugging.
        logger.debug("Could not install signal handlers", exc_info=True)

    t = threading.Thread(target=_monitor_stdin, daemon=True)
    t.start()

    try:
        mcp.run(transport='stdio')
        logger.info("FastMCP run() returned (stdin EOF or disconnect)")
    except (KeyboardInterrupt, SystemExit):
        logger.info("Server interrupted; shutting down")
        _shutdown_flag.set()
    except BrokenPipeError:
        logger.info("Broken pipe; shutting down")
        _shutdown_flag.set()
    except Exception as e:
        logger.error(f"Server error: {e}", exc_info=True)
        _shutdown_flag.set()
        # Exits immediately with status 1; the finally block below is not
        # reached on this path.
        _force_exit(1)
    finally:
        _shutdown_flag.set()
        logger.info("Server main loop exited")
        if not _exit_timer_scheduled.is_set():
            _exit_timer_scheduled.set()
            threading.Timer(0.5, _force_exit, args=(0,)).start()
|
||||
|
||||
|
||||
# Run the server
|
||||
|
|
|
|||
|
|
@ -4,6 +4,9 @@ import logging
|
|||
from logging.handlers import RotatingFileHandler
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
import sys
|
||||
import signal
|
||||
import threading
|
||||
from typing import AsyncIterator, Dict, Any
|
||||
from config import config
|
||||
from tools import register_all_tools
|
||||
|
|
@ -64,6 +67,10 @@ except Exception:
|
|||
# Global connection state
|
||||
_unity_connection: UnityConnection = None
|
||||
|
||||
# Global shutdown coordination
|
||||
_shutdown_flag = threading.Event()
|
||||
_exit_timer_scheduled = threading.Event()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def server_lifespan(server: FastMCP) -> AsyncIterator[Dict[str, Any]]:
|
||||
|
|
@ -186,9 +193,98 @@ register_all_tools(mcp)
|
|||
register_all_resources(mcp)
|
||||
|
||||
|
||||
def _force_exit(code: int = 0):
    """Terminate the process immediately with exit status ``code``.

    Delegates to ``os._exit`` so that background threads which are still
    running cannot delay or block the exit.
    """
    os._exit(code)
|
||||
|
||||
|
||||
def _signal_handler(signum, frame):
    """Signal callback: request shutdown and arm the forced-exit timer once.

    Args:
        signum: Number of the signal that was received (only logged).
        frame: Current stack frame supplied by the signal machinery (unused).
    """
    logger.info(f"Received signal {signum}, initiating shutdown...")
    _shutdown_flag.set()

    # Another shutdown path already armed a timer; do not start a second one.
    if _exit_timer_scheduled.is_set():
        return

    _exit_timer_scheduled.set()
    watchdog = threading.Timer(1.0, _force_exit, args=(0,))
    watchdog.start()
|
||||
|
||||
|
||||
def _monitor_stdin():
    """Background thread to detect stdio detach (stdin closed) or parent exit.

    Wakes every 0.5 s until ``_shutdown_flag`` is set. When the parent
    process has disappeared or ``sys.stdin`` is closed or unusable, it sets
    ``_shutdown_flag`` and, unless an exit timer is already scheduled, arms a
    0.5 s timer that calls ``_force_exit(0)``. Unexpected errors are logged at
    DEBUG level and never propagate out of the thread.
    """
    try:
        # os.kill(pid, 0) is an existence probe only on POSIX. Per the os.kill
        # documentation, on Windows the call either sends a console control
        # event or terminates the target, so the parent is not probed there.
        can_probe = os.name == "posix" and hasattr(os, "getppid")
        parent_pid = os.getppid() if can_probe else None

        # wait() doubles as the poll interval and returns True as soon as
        # shutdown has been requested elsewhere.
        while not _shutdown_flag.wait(0.5):
            if parent_pid is not None:
                try:
                    os.kill(parent_pid, 0)
                except PermissionError:
                    # The process exists but we may not signal it: the
                    # parent is still alive, so keep monitoring.
                    pass
                except ProcessLookupError:
                    logger.info(f"Parent process {parent_pid} no longer exists; shutting down")
                    break
                except (ValueError, OSError):
                    # The probe itself is unusable here; an inconclusive
                    # result must not shut the server down.
                    logger.debug("Parent probe unavailable; disabling it", exc_info=True)
                    parent_pid = None

            try:
                if sys.stdin.closed:
                    logger.info("stdin.closed is True; client disconnected")
                    break
                if sys.stdin.fileno() < 0:
                    logger.info("stdin fd invalid; client disconnected")
                    break
            except (ValueError, OSError, AttributeError):
                # Closed pipe or unavailable stdin
                break
            except Exception:
                # Ignore transient errors
                logger.debug("Transient error checking stdin", exc_info=True)

        if not _shutdown_flag.is_set():
            logger.info("Client disconnected (stdin or parent), initiating shutdown...")
            _shutdown_flag.set()
        # Guarded so that only one forced-exit timer is ever started.
        if not _exit_timer_scheduled.is_set():
            _exit_timer_scheduled.set()
            threading.Timer(0.5, _force_exit, args=(0,)).start()
    except Exception:
        # Never let monitor thread crash the process
        logger.debug("Monitor thread error", exc_info=True)
|
||||
|
||||
|
||||
def main():
    """Entry point for uvx and console scripts.

    Installs the signal handlers, starts the stdin/parent monitor thread and
    then runs the FastMCP server over stdio exactly once. However the run
    ends, the shutdown flag is set and a forced-exit timer is armed (unless
    one is already scheduled).
    """
    # Handlers and monitor must be in place BEFORE the blocking mcp.run() call;
    # running the server ahead of this setup would bypass the shutdown logic.
    try:
        signal.signal(signal.SIGTERM, _signal_handler)
        signal.signal(signal.SIGINT, _signal_handler)
        if hasattr(signal, "SIGPIPE"):
            signal.signal(signal.SIGPIPE, signal.SIG_IGN)
        if hasattr(signal, "SIGBREAK"):
            signal.signal(signal.SIGBREAK, _signal_handler)
    except Exception:
        # Signals can fail in some environments; continue without them but
        # leave a trace for debugging.
        logger.debug("Could not install signal handlers", exc_info=True)

    t = threading.Thread(target=_monitor_stdin, daemon=True)
    t.start()

    try:
        mcp.run(transport='stdio')
        logger.info("FastMCP run() returned (stdin EOF or disconnect)")
    except (KeyboardInterrupt, SystemExit):
        logger.info("Server interrupted; shutting down")
        _shutdown_flag.set()
    except BrokenPipeError:
        logger.info("Broken pipe; shutting down")
        _shutdown_flag.set()
    except Exception as e:
        logger.error(f"Server error: {e}", exc_info=True)
        _shutdown_flag.set()
        # Exits immediately with status 1; the finally block below is not
        # reached on this path.
        _force_exit(1)
    finally:
        _shutdown_flag.set()
        logger.info("Server main loop exited")
        if not _exit_timer_scheduled.is_set():
            _exit_timer_scheduled.set()
            threading.Timer(0.5, _force_exit, args=(0,)).start()
|
||||
|
||||
|
||||
# Run the server
|
||||
|
|
|
|||
|
|
@ -1,161 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for MCP for Unity Telemetry System
|
||||
Run this to verify telemetry is working correctly
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# Add src to Python path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
|
||||
def test_telemetry_basic():
    """Exercise the basic telemetry API; return True on success, else False.

    Returns False when the telemetry names cannot be imported or when
    recording a version record, recording a milestone, or fetching the
    collector raises. An exception from ``is_telemetry_enabled()`` itself is
    not caught here.
    """
    try:
        from telemetry import (
            get_telemetry, record_telemetry, record_milestone,
            RecordType, MilestoneType, is_telemetry_enabled
        )
    except ImportError:
        # Silent failure path for tests
        return False

    # Enabled-status query; the result is not inspected.
    is_telemetry_enabled()

    # Each step must complete without raising; results are discarded.
    steps = (
        lambda: record_telemetry(RecordType.VERSION, {
            "version": "3.0.2",
            "test_run": True
        }),
        lambda: record_milestone(MilestoneType.FIRST_STARTUP, {
            "test_mode": True
        }),
        lambda: get_telemetry(),
    )
    for step in steps:
        try:
            step()
        except Exception:
            # Silent failure path for tests
            return False

    return True
|
||||
|
||||
|
||||
def test_telemetry_disabled():
    """Check telemetry with DISABLE_TELEMETRY set; return True if it reports disabled.

    Sets ``DISABLE_TELEMETRY=true`` in ``os.environ`` (it is not restored
    afterwards), reloads the ``telemetry`` module and returns False if
    telemetry still reports itself enabled.
    """
    import importlib

    # Set environment variable to disable telemetry
    os.environ["DISABLE_TELEMETRY"] = "true"

    # Re-import to get fresh config
    import telemetry
    importlib.reload(telemetry)

    from telemetry import is_telemetry_enabled, record_telemetry, RecordType

    # Status is queried twice: once discarded, once for the decision below.
    is_telemetry_enabled()
    if is_telemetry_enabled():
        return False

    # Test that records are ignored when disabled
    record_telemetry(RecordType.USAGE, {"test": "should_be_ignored"})
    return True
|
||||
|
||||
|
||||
def test_data_storage():
    """Touch the collector's storage paths; return True unless anything raises.

    Reads ``data_dir``, ``uuid_file`` and ``milestones_file`` from the
    collector config and calls ``exists()`` on the two files. Whether the
    files actually exist does not affect the result.
    """
    try:
        from telemetry import get_telemetry

        cfg = get_telemetry().config
        paths = (cfg.data_dir, cfg.uuid_file, cfg.milestones_file)

        # Check if files exist (outcome intentionally ignored)
        for candidate in paths[1:]:
            candidate.exists()
    except Exception:
        # Silent failure path for tests
        return False

    return True
|
||||
|
||||
|
||||
def main():
    """Run all telemetry tests; return True only if none failed.

    A test counts as failed when it returns a falsy value or raises.
    """
    checks = (
        test_telemetry_basic,
        test_data_storage,
        test_telemetry_disabled,
    )

    failures = 0
    for check in checks:
        try:
            ok = check()
        except Exception:
            ok = False
        if not ok:
            failures += 1

    return failures == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Exit status 0 when every test passed, 1 otherwise.
    raise SystemExit(0 if main() else 1)
|
||||
|
|
@ -4,6 +4,9 @@ import logging
|
|||
from logging.handlers import RotatingFileHandler
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import AsyncIterator, Dict, Any
|
||||
from config import config
|
||||
from tools import register_all_tools
|
||||
|
|
@ -63,6 +66,10 @@ except Exception:
|
|||
# Global connection state
|
||||
_unity_connection: UnityConnection = None
|
||||
|
||||
# Global shutdown coordination
|
||||
_shutdown_flag = threading.Event()
|
||||
_exit_timer_scheduled = threading.Event()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def server_lifespan(server: FastMCP) -> AsyncIterator[Dict[str, Any]]:
|
||||
|
|
@ -189,6 +196,99 @@ def asset_creation_strategy() -> str:
|
|||
)
|
||||
|
||||
|
||||
def _force_exit(code: int = 0):
    """Terminate the process immediately with exit status ``code``.

    Delegates to ``os._exit`` so that background threads which are still
    running cannot delay or block the exit.
    """
    os._exit(code)
|
||||
|
||||
|
||||
def _signal_handler(signum, frame):
    """Signal callback: request shutdown and arm the forced-exit timer once.

    Args:
        signum: Number of the signal that was received (only logged).
        frame: Current stack frame supplied by the signal machinery (unused).
    """
    logger.info(f"Received signal {signum}, initiating shutdown...")
    _shutdown_flag.set()

    # Another shutdown path already armed a timer; do not start a second one.
    if _exit_timer_scheduled.is_set():
        return

    _exit_timer_scheduled.set()
    watchdog = threading.Timer(1.0, _force_exit, args=(0,))
    watchdog.start()
|
||||
|
||||
|
||||
def _monitor_stdin():
    """Background thread to detect stdio detach (stdin closed) or parent exit.

    Wakes every 0.5 s until ``_shutdown_flag`` is set. When the parent
    process has disappeared or ``sys.stdin`` is closed or unusable, it sets
    ``_shutdown_flag`` and, unless an exit timer is already scheduled, arms a
    0.5 s timer that calls ``_force_exit(0)``. Unexpected errors are logged at
    DEBUG level and never propagate out of the thread.
    """
    try:
        # os.kill(pid, 0) is an existence probe only on POSIX. Per the os.kill
        # documentation, on Windows the call either sends a console control
        # event or terminates the target, so the parent is not probed there.
        can_probe = os.name == "posix" and hasattr(os, "getppid")
        parent_pid = os.getppid() if can_probe else None

        # wait() doubles as the poll interval and returns True as soon as
        # shutdown has been requested elsewhere.
        while not _shutdown_flag.wait(0.5):
            if parent_pid is not None:
                try:
                    os.kill(parent_pid, 0)
                except PermissionError:
                    # The process exists but we may not signal it: the
                    # parent is still alive, so keep monitoring.
                    pass
                except ProcessLookupError:
                    logger.info(f"Parent process {parent_pid} no longer exists; shutting down")
                    break
                except (ValueError, OSError):
                    # The probe itself is unusable here; an inconclusive
                    # result must not shut the server down.
                    logger.debug("Parent probe unavailable; disabling it", exc_info=True)
                    parent_pid = None

            try:
                if sys.stdin.closed:
                    logger.info("stdin.closed is True; client disconnected")
                    break
                if sys.stdin.fileno() < 0:
                    logger.info("stdin fd invalid; client disconnected")
                    break
            except (ValueError, OSError, AttributeError):
                # Closed pipe or unavailable stdin
                break
            except Exception:
                # Ignore transient errors
                logger.debug("Transient error checking stdin", exc_info=True)

        if not _shutdown_flag.is_set():
            logger.info("Client disconnected (stdin or parent), initiating shutdown...")
            _shutdown_flag.set()
        # Guarded so that only one forced-exit timer is ever started.
        if not _exit_timer_scheduled.is_set():
            _exit_timer_scheduled.set()
            threading.Timer(0.5, _force_exit, args=(0,)).start()
    except Exception:
        # Never let monitor thread crash the process
        logger.debug("Monitor thread error", exc_info=True)
|
||||
|
||||
|
||||
def main():
    """Set up shutdown handling, then run the MCP server over stdio.

    Signal handlers are installed on a best-effort basis, a daemon thread
    running ``_monitor_stdin`` is started, and ``mcp.run`` is called. When the
    run ends the shutdown flag is set and a forced-exit timer is armed unless
    one is already scheduled; an unexpected server error exits with status 1.
    """
    try:
        signal.signal(signal.SIGTERM, _signal_handler)
        signal.signal(signal.SIGINT, _signal_handler)
        # SIGPIPE and SIGBREAK only exist on some platforms.
        sigpipe = getattr(signal, "SIGPIPE", None)
        if sigpipe is not None:
            signal.signal(sigpipe, signal.SIG_IGN)
        sigbreak = getattr(signal, "SIGBREAK", None)
        if sigbreak is not None:
            signal.signal(sigbreak, _signal_handler)
    except Exception:
        # Signals can fail in some environments
        pass

    monitor = threading.Thread(target=_monitor_stdin, daemon=True)
    monitor.start()

    try:
        mcp.run(transport='stdio')
        logger.info("FastMCP run() returned (stdin EOF or disconnect)")
    except (KeyboardInterrupt, SystemExit):
        logger.info("Server interrupted; shutting down")
        _shutdown_flag.set()
    except BrokenPipeError:
        logger.info("Broken pipe; shutting down")
        _shutdown_flag.set()
    except Exception as exc:
        logger.error(f"Server error: {exc}", exc_info=True)
        _shutdown_flag.set()
        _force_exit(1)
    finally:
        _shutdown_flag.set()
        logger.info("Server main loop exited")
        timer_pending = _exit_timer_scheduled.is_set()
        if not timer_pending:
            _exit_timer_scheduled.set()
            exit_timer = threading.Timer(0.5, _force_exit, args=(0,))
            exit_timer.start()
|
||||
|
||||
|
||||
# Run the server
|
||||
if __name__ == "__main__":
    # main() installs the signal handlers and the stdin monitor and then calls
    # mcp.run() itself, so the server must not be started directly here.
    main()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
import importlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
|
||||
# Allow importing telemetry from Server
|
||||
SERVER_DIR = Path(__file__).resolve().parents[1] / "Server"
|
||||
sys.path.insert(0, str(SERVER_DIR))
|
||||
|
||||
@pytest.fixture(autouse=True)
def _cwd(monkeypatch):
    """Run every test with the server ``src`` directory as the working directory.

    Prefers the MCPForUnity layout and falls back to the UnityMcpBridge layout
    only when the former is missing and the latter exists.
    """
    # Ensure telemetry package can locate pyproject.toml via cwd-relative lookup
    repo_root = Path(__file__).resolve().parents[1]
    primary = repo_root / "MCPForUnity" / "UnityMcpServer~" / "src"
    legacy = repo_root / "UnityMcpBridge" / "UnityMcpServer~" / "src"

    target = primary if primary.exists() or not legacy.exists() else legacy
    monkeypatch.chdir(target)
|
||||
|
||||
|
||||
def test_telemetry_basic():
    """Smoke-test the public telemetry API: status, record, milestone, collector."""
    from telemetry import (
        get_telemetry,
        record_telemetry,
        record_milestone,
        RecordType,
        MilestoneType,
        is_telemetry_enabled,
    )

    enabled = is_telemetry_enabled()
    assert isinstance(enabled, bool)

    version_payload = {"version": "3.0.2", "test_run": True}
    record_telemetry(RecordType.VERSION, version_payload)

    milestone_payload = {"test_mode": True}
    was_first = record_milestone(MilestoneType.FIRST_STARTUP, milestone_payload)
    assert isinstance(was_first, bool)

    collector = get_telemetry()
    assert collector is not None
|
||||
|
||||
|
||||
def test_telemetry_disabled(monkeypatch):
    """With DISABLE_TELEMETRY set, telemetry reports disabled and accepts records.

    The module is reloaded after the variable is set, and reloaded again
    without it in a ``finally`` block, so a failing assertion cannot leave the
    module in its disabled state for later tests.
    """
    monkeypatch.setenv("DISABLE_TELEMETRY", "true")
    import telemetry

    try:
        importlib.reload(telemetry)
        from telemetry import is_telemetry_enabled, record_telemetry, RecordType

        assert is_telemetry_enabled() is False
        record_telemetry(RecordType.USAGE, {"test": "ignored"})
    finally:
        # restore module state for later tests, even if the checks above fail
        monkeypatch.delenv("DISABLE_TELEMETRY", raising=False)
        importlib.reload(telemetry)
|
||||
|
||||
|
||||
def test_data_storage():
    """The collector config must expose its data directory and both file paths."""
    from telemetry import get_telemetry

    settings = get_telemetry().config
    for attr in ("data_dir", "uuid_file", "milestones_file"):
        assert getattr(settings, attr) is not None
|
||||
|
||||
Loading…
Reference in New Issue