# Archon/python/tests/test_async_embedding_service.py
# (470 lines, 21 KiB, Python)

"""
Comprehensive Tests for Async Embedding Service
Tests all aspects of the async embedding service after sync function removal.
Covers both success and error scenarios with thorough edge case testing.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import openai
import pytest
from src.server.services.embeddings.embedding_exceptions import (
EmbeddingAPIError,
)
from src.server.services.embeddings.embedding_service import (
EmbeddingBatchResult,
create_embedding,
create_embeddings_batch,
)
class AsyncContextManager:
    """Helper class for properly mocking async context managers.

    Wraps a fixed value so that ``async with AsyncContextManager(v) as x``
    yields ``v`` as ``x``. Exceptions are not suppressed (``__aexit__``
    returns None, which is falsy).
    """

    def __init__(self, return_value):
        # Value handed back to the caller on context entry.
        self.return_value = return_value

    async def __aenter__(self):
        return self.return_value

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Returning None (falsy) propagates any exception raised in the body.
        pass
class TestAsyncEmbeddingService:
    """Test suite for async embedding service functions.

    Every test patches four collaborators of the embedding service module:
    the threading service (rate limiting), the LLM client factory, the
    embedding-model lookup, and the credential service (batch-size config).
    """

    @pytest.fixture
    def mock_llm_client(self):
        """Mock LLM client for testing"""
        mock_client = MagicMock()
        mock_embeddings = MagicMock()
        mock_response = MagicMock()
        mock_response.data = [
            MagicMock(embedding=[0.1, 0.2, 0.3] + [0.0] * 1533)  # 1536 dimensions
        ]
        # embeddings.create must be awaitable in the async service.
        mock_embeddings.create = AsyncMock(return_value=mock_response)
        mock_client.embeddings = mock_embeddings
        return mock_client

    @pytest.fixture
    def mock_threading_service(self):
        """Mock threading service for testing"""
        mock_service = MagicMock()
        # Create a proper async context manager
        rate_limit_ctx = AsyncContextManager(None)
        mock_service.rate_limited_operation.return_value = rate_limit_ctx
        return mock_service

    @pytest.mark.asyncio
    async def test_create_embedding_success(self, mock_llm_client, mock_threading_service):
        """Test successful single embedding creation"""
        with patch(
            "src.server.services.embeddings.embedding_service.get_threading_service",
            return_value=mock_threading_service,
        ):
            with patch(
                "src.server.services.embeddings.embedding_service.get_llm_client"
            ) as mock_get_client:
                with patch(
                    "src.server.services.embeddings.embedding_service.get_embedding_model",
                    return_value="text-embedding-3-small",
                ):
                    with patch(
                        "src.server.services.embeddings.embedding_service.credential_service"
                    ) as mock_cred:
                        # Mock credential service properly
                        mock_cred.get_credentials_by_category = AsyncMock(
                            return_value={"EMBEDDING_BATCH_SIZE": "10"}
                        )
                        # Setup proper async context manager
                        mock_get_client.return_value = AsyncContextManager(mock_llm_client)

                        result = await create_embedding("test text")

                        # Verify the result
                        assert len(result) == 1536
                        assert result[0] == 0.1
                        assert result[1] == 0.2
                        assert result[2] == 0.3

                        # Verify API was called correctly
                        mock_llm_client.embeddings.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_embedding_empty_text(self, mock_llm_client, mock_threading_service):
        """Test embedding creation with empty text"""
        with patch(
            "src.server.services.embeddings.embedding_service.get_threading_service",
            return_value=mock_threading_service,
        ):
            with patch(
                "src.server.services.embeddings.embedding_service.get_llm_client"
            ) as mock_get_client:
                with patch(
                    "src.server.services.embeddings.embedding_service.get_embedding_model",
                    return_value="text-embedding-3-small",
                ):
                    with patch(
                        "src.server.services.embeddings.embedding_service.credential_service"
                    ) as mock_cred:
                        mock_cred.get_credentials_by_category = AsyncMock(
                            return_value={"EMBEDDING_BATCH_SIZE": "10"}
                        )
                        mock_get_client.return_value = AsyncContextManager(mock_llm_client)

                        result = await create_embedding("")

                        # Should still work with empty text
                        assert len(result) == 1536
                        mock_llm_client.embeddings.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_embedding_api_error_raises_exception(self, mock_threading_service):
        """Test embedding creation with API error - should raise exception"""
        with patch(
            "src.server.services.embeddings.embedding_service.get_threading_service",
            return_value=mock_threading_service,
        ):
            with patch(
                "src.server.services.embeddings.embedding_service.get_llm_client"
            ) as mock_get_client:
                with patch(
                    "src.server.services.embeddings.embedding_service.get_embedding_model",
                    return_value="text-embedding-3-small",
                ):
                    with patch(
                        "src.server.services.embeddings.embedding_service.credential_service"
                    ) as mock_cred:
                        mock_cred.get_credentials_by_category = AsyncMock(
                            return_value={"EMBEDDING_BATCH_SIZE": "10"}
                        )
                        # Setup client to raise an error
                        mock_client = MagicMock()
                        mock_client.embeddings.create = AsyncMock(
                            side_effect=Exception("API Error")
                        )
                        mock_get_client.return_value = AsyncContextManager(mock_client)

                        # Should raise exception now instead of returning zero embeddings
                        with pytest.raises(EmbeddingAPIError):
                            await create_embedding("test text")

    @pytest.mark.asyncio
    async def test_create_embeddings_batch_success(self, mock_llm_client, mock_threading_service):
        """Test successful batch embedding creation"""
        # Setup mock response for multiple embeddings
        mock_response = MagicMock()
        mock_response.data = [
            MagicMock(embedding=[0.1, 0.2, 0.3] + [0.0] * 1533),
            MagicMock(embedding=[0.4, 0.5, 0.6] + [0.0] * 1533),
        ]
        mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)

        with patch(
            "src.server.services.embeddings.embedding_service.get_threading_service",
            return_value=mock_threading_service,
        ):
            with patch(
                "src.server.services.embeddings.embedding_service.get_llm_client"
            ) as mock_get_client:
                with patch(
                    "src.server.services.embeddings.embedding_service.get_embedding_model",
                    return_value="text-embedding-3-small",
                ):
                    with patch(
                        "src.server.services.embeddings.embedding_service.credential_service"
                    ) as mock_cred:
                        mock_cred.get_credentials_by_category = AsyncMock(
                            return_value={"EMBEDDING_BATCH_SIZE": "10"}
                        )
                        mock_get_client.return_value = AsyncContextManager(mock_llm_client)

                        result = await create_embeddings_batch(["text1", "text2"])

                        # Verify the result is EmbeddingBatchResult
                        assert isinstance(result, EmbeddingBatchResult)
                        assert result.success_count == 2
                        assert result.failure_count == 0
                        assert len(result.embeddings) == 2
                        assert len(result.embeddings[0]) == 1536
                        assert len(result.embeddings[1]) == 1536
                        assert result.embeddings[0][0] == 0.1
                        assert result.embeddings[1][0] == 0.4
                        mock_llm_client.embeddings.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_embeddings_batch_empty_list(self):
        """Test batch embedding with empty list"""
        result = await create_embeddings_batch([])

        assert isinstance(result, EmbeddingBatchResult)
        assert result.success_count == 0
        assert result.failure_count == 0
        assert result.embeddings == []

    @pytest.mark.asyncio
    async def test_create_embeddings_batch_rate_limit_error(self, mock_threading_service):
        """Test batch embedding with rate limit error"""
        with patch(
            "src.server.services.embeddings.embedding_service.get_threading_service",
            return_value=mock_threading_service,
        ):
            with patch(
                "src.server.services.embeddings.embedding_service.get_llm_client"
            ) as mock_get_client:
                with patch(
                    "src.server.services.embeddings.embedding_service.get_embedding_model",
                    return_value="text-embedding-3-small",
                ):
                    with patch(
                        "src.server.services.embeddings.embedding_service.credential_service"
                    ) as mock_cred:
                        mock_cred.get_credentials_by_category = AsyncMock(
                            return_value={"EMBEDDING_BATCH_SIZE": "10"}
                        )
                        # Setup client to raise rate limit error (not quota)
                        mock_client = MagicMock()
                        # Create a proper RateLimitError with required attributes
                        error = openai.RateLimitError(
                            "Rate limit exceeded",
                            response=MagicMock(),
                            body={"error": {"message": "Rate limit exceeded"}},
                        )
                        mock_client.embeddings.create = AsyncMock(side_effect=error)
                        mock_get_client.return_value = AsyncContextManager(mock_client)

                        result = await create_embeddings_batch(["text1", "text2"])

                        # Should return result with failures, not zero embeddings
                        assert isinstance(result, EmbeddingBatchResult)
                        assert result.success_count == 0
                        assert result.failure_count == 2
                        assert len(result.embeddings) == 0
                        assert len(result.failed_items) == 2

    @pytest.mark.asyncio
    async def test_create_embeddings_batch_quota_exhausted(self, mock_threading_service):
        """Test batch embedding with quota exhausted error"""
        with patch(
            "src.server.services.embeddings.embedding_service.get_threading_service",
            return_value=mock_threading_service,
        ):
            with patch(
                "src.server.services.embeddings.embedding_service.get_llm_client"
            ) as mock_get_client:
                with patch(
                    "src.server.services.embeddings.embedding_service.get_embedding_model",
                    return_value="text-embedding-3-small",
                ):
                    with patch(
                        "src.server.services.embeddings.embedding_service.credential_service"
                    ) as mock_cred:
                        mock_cred.get_credentials_by_category = AsyncMock(
                            return_value={"EMBEDDING_BATCH_SIZE": "10"}
                        )
                        # Setup client to raise quota exhausted error
                        mock_client = MagicMock()
                        error = openai.RateLimitError(
                            "insufficient_quota",
                            response=MagicMock(),
                            body={"error": {"message": "insufficient_quota"}},
                        )
                        mock_client.embeddings.create = AsyncMock(side_effect=error)
                        mock_get_client.return_value = AsyncContextManager(mock_client)

                        # Mock progress callback
                        progress_callback = AsyncMock()

                        result = await create_embeddings_batch(
                            ["text1", "text2"], progress_callback=progress_callback
                        )

                        # Should return result with failures, not zero embeddings
                        assert isinstance(result, EmbeddingBatchResult)
                        assert result.success_count == 0
                        assert result.failure_count == 2
                        assert len(result.embeddings) == 0
                        assert len(result.failed_items) == 2
                        # Verify quota exhausted is in error messages
                        assert any("quota" in item["error"].lower() for item in result.failed_items)

    @pytest.mark.asyncio
    async def test_create_embeddings_batch_with_websocket_progress(
        self, mock_llm_client, mock_threading_service
    ):
        """Test batch embedding with WebSocket progress updates"""
        mock_response = MagicMock()
        mock_response.data = [MagicMock(embedding=[0.1] * 1536)]
        mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)

        with patch(
            "src.server.services.embeddings.embedding_service.get_threading_service",
            return_value=mock_threading_service,
        ):
            with patch(
                "src.server.services.embeddings.embedding_service.get_llm_client"
            ) as mock_get_client:
                with patch(
                    "src.server.services.embeddings.embedding_service.get_embedding_model",
                    return_value="text-embedding-3-small",
                ):
                    with patch(
                        "src.server.services.embeddings.embedding_service.credential_service"
                    ) as mock_cred:
                        mock_cred.get_credentials_by_category = AsyncMock(
                            return_value={"EMBEDDING_BATCH_SIZE": "1"}
                        )
                        mock_get_client.return_value = AsyncContextManager(mock_llm_client)

                        # Mock WebSocket
                        mock_websocket = MagicMock()
                        mock_websocket.send_json = AsyncMock()

                        result = await create_embeddings_batch(["text1"], websocket=mock_websocket)

                        # Verify result is correct
                        assert isinstance(result, EmbeddingBatchResult)
                        assert result.success_count == 1

                        # Verify WebSocket was called
                        mock_websocket.send_json.assert_called()
                        call_args = mock_websocket.send_json.call_args[0][0]
                        assert call_args["type"] == "embedding_progress"
                        assert "processed" in call_args
                        assert "total" in call_args

    @pytest.mark.asyncio
    async def test_create_embeddings_batch_with_progress_callback(
        self, mock_llm_client, mock_threading_service
    ):
        """Test batch embedding with progress callback"""
        mock_response = MagicMock()
        mock_response.data = [MagicMock(embedding=[0.1] * 1536)]
        mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)

        with patch(
            "src.server.services.embeddings.embedding_service.get_threading_service",
            return_value=mock_threading_service,
        ):
            with patch(
                "src.server.services.embeddings.embedding_service.get_llm_client"
            ) as mock_get_client:
                with patch(
                    "src.server.services.embeddings.embedding_service.get_embedding_model",
                    return_value="text-embedding-3-small",
                ):
                    with patch(
                        "src.server.services.embeddings.embedding_service.credential_service"
                    ) as mock_cred:
                        mock_cred.get_credentials_by_category = AsyncMock(
                            return_value={"EMBEDDING_BATCH_SIZE": "1"}
                        )
                        mock_get_client.return_value = AsyncContextManager(mock_llm_client)

                        # Mock progress callback
                        progress_callback = AsyncMock()

                        result = await create_embeddings_batch(
                            ["text1"], progress_callback=progress_callback
                        )

                        # Verify result
                        assert isinstance(result, EmbeddingBatchResult)
                        assert result.success_count == 1

                        # Verify progress callback was called
                        progress_callback.assert_called()

    @pytest.mark.asyncio
    async def test_provider_override(self, mock_llm_client, mock_threading_service):
        """Test that provider override parameter is properly passed through"""
        mock_response = MagicMock()
        mock_response.data = [MagicMock(embedding=[0.1] * 1536)]
        mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)

        with patch(
            "src.server.services.embeddings.embedding_service.get_threading_service",
            return_value=mock_threading_service,
        ):
            with patch(
                "src.server.services.embeddings.embedding_service.get_llm_client"
            ) as mock_get_client:
                with patch(
                    "src.server.services.embeddings.embedding_service.get_embedding_model"
                ) as mock_get_model:
                    with patch(
                        "src.server.services.embeddings.embedding_service.credential_service"
                    ) as mock_cred:
                        mock_cred.get_credentials_by_category = AsyncMock(
                            return_value={"EMBEDDING_BATCH_SIZE": "10"}
                        )
                        mock_get_model.return_value = "custom-model"
                        mock_get_client.return_value = AsyncContextManager(mock_llm_client)

                        await create_embedding("test text", provider="custom-provider")

                        # Verify provider was passed to get_llm_client
                        mock_get_client.assert_called_with(
                            provider="custom-provider", use_embedding_provider=True
                        )
                        mock_get_model.assert_called_with(provider="custom-provider")

    @pytest.mark.asyncio
    async def test_create_embeddings_batch_large_batch_splitting(
        self, mock_llm_client, mock_threading_service
    ):
        """Test that large batches are properly split according to batch size settings"""
        mock_response = MagicMock()
        mock_response.data = [
            MagicMock(embedding=[0.1] * 1536) for _ in range(2)
        ]  # 2 embeddings per call
        mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)

        with patch(
            "src.server.services.embeddings.embedding_service.get_threading_service",
            return_value=mock_threading_service,
        ):
            with patch(
                "src.server.services.embeddings.embedding_service.get_llm_client"
            ) as mock_get_client:
                with patch(
                    "src.server.services.embeddings.embedding_service.get_embedding_model",
                    return_value="text-embedding-3-small",
                ):
                    with patch(
                        "src.server.services.embeddings.embedding_service.credential_service"
                    ) as mock_cred:
                        # Set batch size to 2
                        mock_cred.get_credentials_by_category = AsyncMock(
                            return_value={"EMBEDDING_BATCH_SIZE": "2"}
                        )
                        mock_get_client.return_value = AsyncContextManager(mock_llm_client)

                        # Test with 5 texts (should require 3 API calls: 2+2+1)
                        texts = ["text1", "text2", "text3", "text4", "text5"]
                        result = await create_embeddings_batch(texts)

                        # Should have made 3 API calls due to batching
                        assert mock_llm_client.embeddings.create.call_count == 3

                        # Result should be EmbeddingBatchResult
                        assert isinstance(result, EmbeddingBatchResult)
                        # Should have 5 embeddings total (for 5 input texts)
                        # Even though mock returns 2 per call, we only process as many as we requested
                        assert result.success_count == 5
                        assert len(result.embeddings) == 5
                        assert result.texts_processed == texts