457 lines
16 KiB
Python
457 lines
16 KiB
Python
|
|
"""Tests for ApiKeyService: validation, caching, retries, and singleton lifecycle."""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import time
|
||
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from services.api_key_service import ApiKeyService, ValidationResult
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture(autouse=True)
def _reset_singleton():
    """Give every test a clean ``ApiKeyService`` singleton.

    The class-level ``_instance`` slot is cleared before the test body runs
    and cleared again during fixture teardown, so a service constructed by
    one test is never visible to the next.
    """

    def _forget_instance():
        ApiKeyService._instance = None

    _forget_instance()  # setup
    yield
    _forget_instance()  # teardown
|
||
|
|
|
||
|
|
|
||
|
|
def _make_service(
    validation_url="https://auth.example.com/validate",
    cache_ttl=300.0,
    service_token_header=None,
    service_token=None,
):
    """Construct an ``ApiKeyService`` with test-friendly defaults.

    Every argument is forwarded by keyword to ``ApiKeyService``. Building a
    service also makes it the object returned by
    ``ApiKeyService.get_instance()``.
    """
    settings = {
        "validation_url": validation_url,
        "cache_ttl": cache_ttl,
        "service_token_header": service_token_header,
        "service_token": service_token,
    }
    return ApiKeyService(**settings)
|
||
|
|
|
||
|
|
|
||
|
|
def _mock_response(status_code=200, json_data=None):
    """Build a stand-in for an ``httpx.Response``.

    Args:
        status_code: Value exposed as ``resp.status_code``.
        json_data: Payload returned by ``resp.json()``. ``None`` (the
            default) means an empty dict. Any other value, including falsy
            ones such as ``[]``, is returned unchanged.

    Returns:
        A ``MagicMock`` spec'd on ``httpx.Response``. Accessing an attribute
        the real class does not define raises ``AttributeError``.
    """
    resp = MagicMock(spec=httpx.Response)
    resp.status_code = status_code
    # Compare against None explicitly: ``json_data or {}`` would silently
    # swap deliberate falsy payloads (``[]``, ``""``, ``0``) for ``{}``.
    resp.json.return_value = {} if json_data is None else json_data
    return resp
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Singleton lifecycle
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestSingletonLifecycle:
    """Class-level accessors for the ``ApiKeyService`` singleton."""

    def test_get_instance_before_init_raises(self):
        # The autouse fixture has cleared any previously created instance.
        with pytest.raises(RuntimeError, match="not initialized"):
            ApiKeyService.get_instance()

    def test_is_initialized_false_before_init(self):
        initialized = ApiKeyService.is_initialized()
        assert initialized is False

    def test_is_initialized_true_after_init(self):
        _make_service()
        initialized = ApiKeyService.is_initialized()
        assert initialized is True

    def test_get_instance_returns_service(self):
        created = _make_service()
        fetched = ApiKeyService.get_instance()
        assert fetched is created
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Basic validation
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestBasicValidation:
    """Single-request outcomes of ``ApiKeyService.validate``."""

    @staticmethod
    def _wire_client(client_cls, response):
        """Make the patched ``httpx.AsyncClient`` yield a fake client.

        The fake acts as its own async context manager, and its ``post``
        coroutine resolves to *response*.
        """
        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        client.post = AsyncMock(return_value=response)
        client_cls.return_value = client
        return client

    @pytest.mark.asyncio
    async def test_valid_key(self):
        svc = _make_service()
        reply = _mock_response(
            200, {"valid": True, "user_id": "user-1", "metadata": {"plan": "pro"}}
        )

        with patch("httpx.AsyncClient") as client_cls:
            self._wire_client(client_cls, reply)
            outcome = await svc.validate("test-valid-key-12345678")

        assert outcome.valid is True
        assert outcome.user_id == "user-1"
        assert outcome.metadata == {"plan": "pro"}

    @pytest.mark.asyncio
    async def test_invalid_key_200_body(self):
        svc = _make_service()
        reply = _mock_response(200, {"valid": False, "error": "Key revoked"})

        with patch("httpx.AsyncClient") as client_cls:
            self._wire_client(client_cls, reply)
            outcome = await svc.validate("test-invalid-key-1234")

        assert outcome.valid is False
        assert outcome.error == "Key revoked"

    @pytest.mark.asyncio
    async def test_invalid_key_401_status(self):
        svc = _make_service()
        reply = _mock_response(401)

        with patch("httpx.AsyncClient") as client_cls:
            self._wire_client(client_cls, reply)
            outcome = await svc.validate("test-bad-key-12345678")

        assert outcome.valid is False
        assert "Invalid API key" in outcome.error

    @pytest.mark.asyncio
    async def test_empty_key_fast_path(self):
        svc = _make_service()

        with patch("httpx.AsyncClient") as client_cls:
            outcome = await svc.validate("")

        assert outcome.valid is False
        assert "required" in outcome.error.lower()
        # The empty key must be rejected before any HTTP client is built.
        client_cls.assert_not_called()
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Caching
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestCaching:
    """Result caching performed by ``ApiKeyService.validate``."""

    @staticmethod
    def _counting_post(response):
        """Return ``(post, calls)``.

        ``post`` is a fake ``AsyncClient.post`` that always resolves to
        *response*. ``calls`` is a list that receives one entry per
        invocation.
        """
        calls = []

        async def post(*args, **kwargs):
            calls.append((args, kwargs))
            return response

        return post, calls

    @staticmethod
    def _wire_client(client_cls, post):
        """Make the patched ``httpx.AsyncClient`` yield a client using *post*."""
        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        client.post = post
        client_cls.return_value = client
        return client

    @pytest.mark.asyncio
    async def test_cache_hit_valid_key(self):
        svc = _make_service(cache_ttl=300.0)
        post, calls = self._counting_post(
            _mock_response(200, {"valid": True, "user_id": "u1"})
        )

        with patch("httpx.AsyncClient") as client_cls:
            self._wire_client(client_cls, post)
            first = await svc.validate("test-cached-valid-key1")
            second = await svc.validate("test-cached-valid-key1")

        assert first.valid is True
        assert second.valid is True
        assert second.user_id == "u1"
        assert len(calls) == 1  # Only one HTTP call

    @pytest.mark.asyncio
    async def test_cache_hit_invalid_key(self):
        svc = _make_service(cache_ttl=300.0)
        post, calls = self._counting_post(
            _mock_response(200, {"valid": False, "error": "bad"})
        )

        with patch("httpx.AsyncClient") as client_cls:
            self._wire_client(client_cls, post)
            first = await svc.validate("test-cached-bad-key12")
            second = await svc.validate("test-cached-bad-key12")

        assert first.valid is False
        assert second.valid is False
        assert len(calls) == 1  # the rejection was served from cache

    @pytest.mark.asyncio
    async def test_cache_expiry(self):
        key = "test-expiry-key-12345"
        svc = _make_service(cache_ttl=1.0)  # 1 second TTL
        post, calls = self._counting_post(
            _mock_response(200, {"valid": True, "user_id": "u1"})
        )

        with patch("httpx.AsyncClient") as client_cls:
            self._wire_client(client_cls, post)

            await svc.validate(key)
            assert len(calls) == 1

            # Backdate the stored expiry so the entry is already stale,
            # rather than sleeping past the TTL.
            async with svc._cache_lock:
                valid, user_id, metadata, _expires = svc._cache[key]
                svc._cache[key] = (valid, user_id, metadata, time.time() - 1)

            await svc.validate(key)
            assert len(calls) == 2  # Had to re-validate

    @pytest.mark.asyncio
    async def test_invalidate_cache(self):
        key = "test-invalidate-key12"
        svc = _make_service(cache_ttl=300.0)
        post, calls = self._counting_post(
            _mock_response(200, {"valid": True, "user_id": "u1"})
        )

        with patch("httpx.AsyncClient") as client_cls:
            self._wire_client(client_cls, post)

            await svc.validate(key)
            assert len(calls) == 1

            await svc.invalidate_cache(key)

            await svc.validate(key)
            assert len(calls) == 2

    @pytest.mark.asyncio
    async def test_clear_cache(self):
        keys = ("test-clear-key1-12345", "test-clear-key2-12345")
        svc = _make_service(cache_ttl=300.0)
        post, calls = self._counting_post(
            _mock_response(200, {"valid": True, "user_id": "u1"})
        )

        with patch("httpx.AsyncClient") as client_cls:
            self._wire_client(client_cls, post)

            for key in keys:
                await svc.validate(key)
            assert len(calls) == 2

            await svc.clear_cache()

            for key in keys:
                await svc.validate(key)
            assert len(calls) == 4  # Both had to re-validate
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Transient failures & retries
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestTransientFailures:
    """Behavior of ``validate`` when the validation endpoint misbehaves.

    Transient failures (5xx, timeouts, connection errors, unexpected
    exceptions) must surface as invalid, non-cacheable results.
    """

    @staticmethod
    def _wire_client(client_cls, post):
        """Make the patched ``httpx.AsyncClient`` yield a client using *post*."""
        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        client.post = post
        client_cls.return_value = client
        return client

    @pytest.mark.asyncio
    async def test_5xx_not_cached(self):
        key = "test-5xx-test-key1234"
        svc = _make_service(cache_ttl=300.0)
        mock_500 = _mock_response(500)
        mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"})
        backend_healthy = False
        calls = 0

        async def flaky_post(*args, **kwargs):
            # Serve 500s until the test flips ``backend_healthy``. This is
            # independent of how many attempts validate() makes per call.
            nonlocal calls
            calls += 1
            return mock_ok if backend_healthy else mock_500

        with patch("httpx.AsyncClient") as MockClient:
            self._wire_client(MockClient, flaky_post)

            # First call: 500 -> invalid and flagged as not cacheable.
            r1 = await svc.validate(key)
            assert r1.valid is False
            assert r1.cacheable is False
            calls_after_first = calls

            # Second call must reach HTTP again. A cached 5xx would answer
            # without calling post(), yet still report valid=False, so the
            # call count is the only real proof.
            r2 = await svc.validate(key)
            assert r2.valid is False
            assert calls > calls_after_first

            # Once the backend recovers, the same key must validate.
            backend_healthy = True
            r3 = await svc.validate(key)
            assert r3.valid is True
            assert r3.user_id == "u1"

    @pytest.mark.asyncio
    async def test_timeout_then_retry_succeeds(self):
        svc = _make_service()
        mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"})
        attempt = 0

        async def timeout_then_ok(*args, **kwargs):
            nonlocal attempt
            attempt += 1
            if attempt == 1:
                raise httpx.TimeoutException("timed out")
            return mock_ok

        with patch("httpx.AsyncClient") as MockClient:
            self._wire_client(MockClient, timeout_then_ok)
            result = await svc.validate("test-timeout-retry-ok")

        assert result.valid is True
        assert result.user_id == "u1"
        assert attempt == 2

    @pytest.mark.asyncio
    async def test_timeout_exhausts_retries(self):
        svc = _make_service()
        attempts = 0

        async def always_timeout(*args, **kwargs):
            nonlocal attempts
            attempts += 1
            raise httpx.TimeoutException("timed out")

        with patch("httpx.AsyncClient") as MockClient:
            self._wire_client(MockClient, always_timeout)
            result = await svc.validate("test-timeout-exhaust1")

        assert result.valid is False
        assert "timeout" in result.error.lower()
        assert result.cacheable is False
        # test_timeout_then_retry_succeeds shows a timeout is retried under
        # the default settings, so exhausting retries implies >1 attempt.
        assert attempts >= 2

    @pytest.mark.asyncio
    async def test_request_error_then_retry_succeeds(self):
        svc = _make_service()
        mock_ok = _mock_response(200, {"valid": True, "user_id": "u1"})
        attempt = 0

        async def error_then_ok(*args, **kwargs):
            nonlocal attempt
            attempt += 1
            if attempt == 1:
                raise httpx.ConnectError("connection refused")
            return mock_ok

        with patch("httpx.AsyncClient") as MockClient:
            self._wire_client(MockClient, error_then_ok)
            result = await svc.validate("test-reqerr-retry-ok1")

        assert result.valid is True
        assert attempt == 2

    @pytest.mark.asyncio
    async def test_request_error_exhausts_retries(self):
        svc = _make_service()
        attempts = 0

        async def always_error(*args, **kwargs):
            nonlocal attempts
            attempts += 1
            raise httpx.ConnectError("connection refused")

        with patch("httpx.AsyncClient") as MockClient:
            self._wire_client(MockClient, always_error)
            result = await svc.validate("test-reqerr-exhaust1")

        assert result.valid is False
        assert "unavailable" in result.error.lower()
        assert result.cacheable is False
        # test_request_error_then_retry_succeeds shows connection errors are
        # retried under the default settings, so exhaustion implies >1 attempt.
        assert attempts >= 2

    @pytest.mark.asyncio
    async def test_unexpected_exception(self):
        svc = _make_service()

        async def unexpected(*args, **kwargs):
            raise ValueError("something unexpected")

        with patch("httpx.AsyncClient") as MockClient:
            self._wire_client(MockClient, unexpected)
            result = await svc.validate("test-unexpected-err12")

        assert result.valid is False
        assert result.cacheable is False
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Service token
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestServiceToken:
    """Forwarding of the optional service-to-service auth header."""

    @pytest.mark.asyncio
    async def test_service_token_sent_in_headers(self):
        svc = _make_service(
            service_token_header="X-Service-Token",
            service_token="test-svc-token-123",
        )
        mock_resp = _mock_response(200, {"valid": True, "user_id": "u1"})
        captured_headers = {}
        post_calls = 0

        async def capture_post(url, *, json=None, headers=None, **kwargs):
            # Accept and ignore any extra keyword the service may pass. A
            # strict signature would raise TypeError inside validate(), which
            # gets reported as an invalid key and then shows up misleadingly
            # as a "missing header" assertion failure below.
            nonlocal post_calls
            post_calls += 1
            captured_headers.update(headers or {})
            return mock_resp

        with patch("httpx.AsyncClient") as MockClient:
            instance = AsyncMock()
            instance.__aenter__ = AsyncMock(return_value=instance)
            instance.__aexit__ = AsyncMock(return_value=False)
            instance.post = capture_post
            MockClient.return_value = instance

            result = await svc.validate("test-svctoken-key1234")

        # Confirm the request actually went out and succeeded before
        # inspecting what headers it carried.
        assert post_calls >= 1
        assert result.valid is True
        assert captured_headers.get("X-Service-Token") == "test-svc-token-123"
        assert captured_headers.get("Content-Type") == "application/json"
|