"""Tests for ApiKeyService: validation, caching, retries, and singleton lifecycle.""" import asyncio import time from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest from services.api_key_service import ApiKeyService, ValidationResult @pytest.fixture(autouse=True) def _reset_singleton(): """Reset the ApiKeyService singleton between tests.""" ApiKeyService._instance = None yield ApiKeyService._instance = None def _make_service( validation_url="https://auth.example.com/validate", cache_ttl=300.0, service_token_header=None, service_token=None, ): return ApiKeyService( validation_url=validation_url, cache_ttl=cache_ttl, service_token_header=service_token_header, service_token=service_token, ) def _mock_response(status_code=200, json_data=None): resp = MagicMock(spec=httpx.Response) resp.status_code = status_code resp.json.return_value = json_data or {} return resp # --------------------------------------------------------------------------- # Singleton lifecycle # --------------------------------------------------------------------------- class TestSingletonLifecycle: def test_get_instance_before_init_raises(self): with pytest.raises(RuntimeError, match="not initialized"): ApiKeyService.get_instance() def test_is_initialized_false_before_init(self): assert ApiKeyService.is_initialized() is False def test_is_initialized_true_after_init(self): _make_service() assert ApiKeyService.is_initialized() is True def test_get_instance_returns_service(self): svc = _make_service() assert ApiKeyService.get_instance() is svc # --------------------------------------------------------------------------- # Basic validation # --------------------------------------------------------------------------- class TestBasicValidation: @pytest.mark.asyncio async def test_valid_key(self): svc = _make_service() mock_resp = _mock_response( 200, {"valid": True, "user_id": "user-1", "metadata": {"plan": "pro"}}) with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = AsyncMock(return_value=mock_resp) MockClient.return_value = instance result = await svc.validate("test-valid-key-12345678") assert result.valid is True assert result.user_id == "user-1" assert result.metadata == {"plan": "pro"} @pytest.mark.asyncio async def test_invalid_key_200_body(self): svc = _make_service() mock_resp = _mock_response( 200, {"valid": False, "error": "Key revoked"}) with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = AsyncMock(return_value=mock_resp) MockClient.return_value = instance result = await svc.validate("test-invalid-key-1234") assert result.valid is False assert result.error == "Key revoked" @pytest.mark.asyncio async def test_invalid_key_401_status(self): svc = _make_service() mock_resp = _mock_response(401) with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = AsyncMock(return_value=mock_resp) MockClient.return_value = instance result = await svc.validate("test-bad-key-12345678") assert result.valid is False assert "Invalid API key" in result.error @pytest.mark.asyncio async def test_empty_key_fast_path(self): svc = _make_service() with patch("httpx.AsyncClient") as MockClient: result = await svc.validate("") assert result.valid is False assert "required" in result.error.lower() # No HTTP call should have been made MockClient.assert_not_called() # --------------------------------------------------------------------------- # Caching # --------------------------------------------------------------------------- class TestCaching: @pytest.mark.asyncio async def test_cache_hit_valid_key(self): svc = _make_service(cache_ttl=300.0) mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"}) call_count = 0 async def counting_post(*args, **kwargs): nonlocal call_count call_count += 1 return mock_resp with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = counting_post MockClient.return_value = instance r1 = await svc.validate("test-cached-valid-key1") r2 = await svc.validate("test-cached-valid-key1") assert r1.valid is True assert r2.valid is True assert r2.user_id == "u1" assert call_count == 1 # Only one HTTP call @pytest.mark.asyncio async def test_cache_hit_invalid_key(self): svc = _make_service(cache_ttl=300.0) mock_resp = _mock_response(200, {"valid": False, "error": "bad"}) call_count = 0 async def counting_post(*args, **kwargs): nonlocal call_count call_count += 1 return mock_resp with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = counting_post MockClient.return_value = instance r1 = await svc.validate("test-cached-bad-key12") r2 = await svc.validate("test-cached-bad-key12") assert r1.valid is False assert r2.valid is False assert call_count == 1 @pytest.mark.asyncio async def test_cache_expiry(self): svc = _make_service(cache_ttl=1.0) # 1 second TTL mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"}) call_count = 0 async def counting_post(*args, **kwargs): nonlocal call_count call_count += 1 return mock_resp with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = counting_post MockClient.return_value = instance await svc.validate("test-expiry-key-12345") assert call_count == 1 # Manually expire the cache entry by manipulating the stored tuple async with svc._cache_lock: key = "test-expiry-key-12345" valid, user_id, metadata, _expires = svc._cache[key] svc._cache[key] = (valid, user_id, metadata, time.time() - 1) await svc.validate("test-expiry-key-12345") assert call_count == 2 # Had to re-validate @pytest.mark.asyncio async def test_invalidate_cache(self): svc = _make_service(cache_ttl=300.0) mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"}) call_count = 0 async def counting_post(*args, **kwargs): nonlocal call_count call_count += 1 return mock_resp with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = counting_post MockClient.return_value = instance await svc.validate("test-invalidate-key12") assert call_count == 1 await svc.invalidate_cache("test-invalidate-key12") await svc.validate("test-invalidate-key12") assert call_count == 2 @pytest.mark.asyncio async def test_clear_cache(self): svc = _make_service(cache_ttl=300.0) mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"}) call_count = 0 async def counting_post(*args, **kwargs): nonlocal call_count call_count += 1 return mock_resp with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = counting_post MockClient.return_value = instance await svc.validate("test-clear-key1-12345") await svc.validate("test-clear-key2-12345") assert call_count == 2 await svc.clear_cache() await svc.validate("test-clear-key1-12345") await svc.validate("test-clear-key2-12345") assert call_count == 4 # Both had to re-validate # --------------------------------------------------------------------------- # Transient failures & retries # --------------------------------------------------------------------------- class TestTransientFailures: @pytest.mark.asyncio async def test_5xx_not_cached(self): svc = _make_service(cache_ttl=300.0) mock_500 = _mock_response(500) mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"}) responses = [mock_500, mock_500, mock_ok] # Extra for retry call_idx = 0 async def sequential_post(*args, **kwargs): nonlocal call_idx resp = responses[min(call_idx, len(responses) - 1)] call_idx += 1 return resp with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = sequential_post MockClient.return_value = instance # First call: 500 -> not cached r1 = await svc.validate("test-5xx-test-key1234") assert r1.valid is False assert r1.cacheable is False # Second call should hit HTTP again (not cached) r2 = await svc.validate("test-5xx-test-key1234") # Second call also gets 500 from our mock sequence assert r2.valid is False @pytest.mark.asyncio async def test_timeout_then_retry_succeeds(self): svc = _make_service() mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"}) attempt = 0 async def timeout_then_ok(*args, **kwargs): nonlocal attempt attempt += 1 if attempt == 1: raise httpx.TimeoutException("timed out") return mock_ok with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = timeout_then_ok MockClient.return_value = instance result = await svc.validate("test-timeout-retry-ok") assert result.valid is True assert result.user_id == "u1" assert attempt == 2 @pytest.mark.asyncio async def test_timeout_exhausts_retries(self): svc = _make_service() async def always_timeout(*args, **kwargs): raise httpx.TimeoutException("timed out") with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = always_timeout MockClient.return_value = instance result = await svc.validate("test-timeout-exhaust1") assert result.valid is False assert "timeout" in result.error.lower() assert result.cacheable is False @pytest.mark.asyncio async def test_request_error_then_retry_succeeds(self): svc = _make_service() mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"}) attempt = 0 async def error_then_ok(*args, **kwargs): nonlocal attempt attempt += 1 if attempt == 1: raise httpx.ConnectError("connection refused") return mock_ok with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = error_then_ok MockClient.return_value = instance result = await svc.validate("test-reqerr-retry-ok1") assert result.valid is True assert attempt == 2 @pytest.mark.asyncio async def test_request_error_exhausts_retries(self): svc = _make_service() async def always_error(*args, **kwargs): raise httpx.ConnectError("connection refused") with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = always_error MockClient.return_value = instance result = await svc.validate("test-reqerr-exhaust1") assert result.valid is False assert "unavailable" in result.error.lower() assert result.cacheable is False @pytest.mark.asyncio async def test_unexpected_exception(self): svc = _make_service() async def unexpected(*args, **kwargs): raise ValueError("something unexpected") with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = unexpected MockClient.return_value = instance result = await svc.validate("test-unexpected-err12") assert result.valid is False assert result.cacheable is False # --------------------------------------------------------------------------- # Service token # --------------------------------------------------------------------------- class TestServiceToken: @pytest.mark.asyncio async def test_service_token_sent_in_headers(self): svc = _make_service( service_token_header="X-Service-Token", service_token="test-svc-token-123", ) mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"}) captured_headers = {} async def capture_post(url, *, json=None, headers=None): captured_headers.update(headers or {}) return mock_resp with patch("httpx.AsyncClient") as MockClient: instance = AsyncMock() instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.post = capture_post MockClient.return_value = instance await svc.validate("test-svctoken-key1234") assert captured_headers.get("X-Service-Token") == "test-svc-token-123" assert captured_headers.get("Content-Type") == "application/json"