333 lines
14 KiB
Python
333 lines
14 KiB
Python
"""
Tests for embedding service to ensure no zero embeddings are returned.

These tests verify that the embedding service raises appropriate exceptions
instead of returning zero embeddings, following the "fail fast and loud" principle.
"""
|
|
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
|
|
import openai
|
|
import pytest
|
|
|
|
from src.server.services.embeddings.embedding_exceptions import (
|
|
EmbeddingAPIError,
|
|
EmbeddingQuotaExhaustedError,
|
|
EmbeddingRateLimitError,
|
|
)
|
|
from src.server.services.embeddings.embedding_service import (
|
|
EmbeddingBatchResult,
|
|
create_embedding,
|
|
create_embeddings_batch,
|
|
)
|
|
|
|
|
|
class TestNoZeroEmbeddings:
    """Check that failures never come back disguised as all-zero embeddings.

    Each test patches ``get_llm_client`` in the embedding service module with a
    mock async context manager, then checks that errors surface either as a
    raised embedding exception (single-text API) or as failed items on an
    ``EmbeddingBatchResult`` (batch API).
    """

    # The old test_sync_from_async_context_raises_exception was removed:
    # sync versions no longer exist, everything is async-only now.

    # Dotted paths of the collaborators patched by the tests below.
    _LLM_CLIENT = "src.server.services.embeddings.embedding_service.get_llm_client"
    _EMBEDDING_MODEL = "src.server.services.embeddings.embedding_service.get_embedding_model"
    _CREDENTIALS = "src.server.services.embeddings.embedding_service.credential_service.get_credentials_by_category"

    @staticmethod
    def _client_with_create_effect(effect):
        """Return an async-context mock whose ``embeddings.create`` has *effect* as side_effect.

        *effect* is handed to ``side_effect`` unchanged, so it may be a single
        exception instance or a list of per-call outcomes.
        """
        client_ctx = AsyncMock()
        client_ctx.__aenter__.return_value.embeddings.create.side_effect = effect
        return client_ctx

    @pytest.mark.asyncio
    async def test_async_quota_exhausted_returns_failure(self) -> None:
        """An ``insufficient_quota`` error must raise EmbeddingQuotaExhaustedError."""
        with patch(self._LLM_CLIENT) as get_client:
            quota_error = openai.RateLimitError(
                "insufficient_quota: You have exceeded your quota", response=Mock(), body=None
            )
            get_client.return_value = self._client_with_create_effect(quota_error)

            # The single-text API still raises (kept for backward compatibility).
            with pytest.raises(EmbeddingQuotaExhaustedError) as caught:
                await create_embedding("test text")

        message = str(caught.value).lower()
        assert "quota exhausted" in message

    @pytest.mark.asyncio
    async def test_async_rate_limit_raises_exception(self) -> None:
        """A ``rate_limit_exceeded`` error must raise EmbeddingRateLimitError."""
        with patch(self._LLM_CLIENT) as get_client:
            throttled = openai.RateLimitError(
                "rate_limit_exceeded: Too many requests", response=Mock(), body=None
            )
            get_client.return_value = self._client_with_create_effect(throttled)

            with pytest.raises(EmbeddingRateLimitError) as caught:
                await create_embedding("test text")

        message = str(caught.value).lower()
        assert "rate limit" in message

    @pytest.mark.asyncio
    async def test_async_api_error_raises_exception(self) -> None:
        """A generic client exception must raise EmbeddingAPIError, not yield zeros."""
        with patch(self._LLM_CLIENT) as get_client:
            get_client.return_value = self._client_with_create_effect(
                Exception("Network error")
            )

            with pytest.raises(EmbeddingAPIError) as caught:
                await create_embedding("test text")

        message = str(caught.value).lower()
        assert "failed to create embedding" in message

    @pytest.mark.asyncio
    async def test_batch_handles_partial_failures(self) -> None:
        """One good batch followed by one failing batch yields 2 successes and 2 failures."""
        good_response = Mock()
        good_response.data = [Mock(embedding=[0.1] * 1536), Mock(embedding=[0.2] * 1536)]
        # Per-call outcomes: the first create() returns data, the second raises.
        create_outcomes = [good_response, Exception("API Error")]

        model_patch = patch(
            self._EMBEDDING_MODEL,
            new_callable=AsyncMock,
            return_value="text-embedding-ada-002",
        )
        # Settings mock reports a batch size of 2, so 4 texts are split in two.
        settings_patch = patch(
            self._CREDENTIALS,
            new_callable=AsyncMock,
            return_value={"EMBEDDING_BATCH_SIZE": "2"},
        )

        with patch(self._LLM_CLIENT) as get_client, model_patch, settings_patch:
            get_client.return_value = self._client_with_create_effect(create_outcomes)
            outcome = await create_embeddings_batch(["text1", "text2", "text3", "text4"])

        # Shape of the result: first batch succeeded, second batch failed.
        assert isinstance(outcome, EmbeddingBatchResult)
        assert outcome.success_count == 2
        assert outcome.failure_count == 2
        assert len(outcome.embeddings) == 2
        assert len(outcome.failed_items) == 2

        # None of the returned vectors may be all zeros.
        for vector in outcome.embeddings:
            assert any(component != 0.0 for component in vector)

    @pytest.mark.asyncio
    async def test_configurable_embedding_dimensions(self) -> None:
        """``EMBEDDING_DIMENSIONS`` from settings is forwarded as ``dimensions``."""
        response = Mock()
        response.data = [Mock(embedding=[0.1] * 3072)]  # non-default size
        create_call = AsyncMock(return_value=response)

        client_ctx = AsyncMock()
        client_ctx.__aenter__.return_value.embeddings.create = create_call

        model_patch = patch(
            self._EMBEDDING_MODEL,
            new_callable=AsyncMock,
            return_value="text-embedding-3-large",
        )
        # Settings mock supplies the custom dimension count.
        settings_patch = patch(
            self._CREDENTIALS,
            new_callable=AsyncMock,
            return_value={"EMBEDDING_DIMENSIONS": "3072"},
        )

        with patch(self._LLM_CLIENT) as get_client, model_patch, settings_patch:
            get_client.return_value = client_ctx
            outcome = await create_embeddings_batch(["test text"])

        # The dimensions keyword must have reached the client call.
        create_call.assert_called_once()
        assert create_call.call_args.kwargs["dimensions"] == 3072

        assert outcome.success_count == 1
        assert len(outcome.embeddings[0]) == 3072

    @pytest.mark.asyncio
    async def test_default_embedding_dimensions(self) -> None:
        """With no dimension setting, ``dimensions`` falls back to 1536."""
        response = Mock()
        response.data = [Mock(embedding=[0.1] * 1536)]
        create_call = AsyncMock(return_value=response)

        client_ctx = AsyncMock()
        client_ctx.__aenter__.return_value.embeddings.create = create_call

        model_patch = patch(
            self._EMBEDDING_MODEL,
            new_callable=AsyncMock,
            return_value="text-embedding-3-small",
        )
        # Empty settings: no dimensions specified.
        settings_patch = patch(
            self._CREDENTIALS,
            new_callable=AsyncMock,
            return_value={},
        )

        with patch(self._LLM_CLIENT) as get_client, model_patch, settings_patch:
            get_client.return_value = client_ctx
            outcome = await create_embeddings_batch(["test text"])

        # The default dimensions keyword must have been used.
        create_call.assert_called_once()
        assert create_call.call_args.kwargs["dimensions"] == 1536

        assert outcome.success_count == 1
        assert len(outcome.embeddings[0]) == 1536

    @pytest.mark.asyncio
    async def test_batch_quota_exhausted_stops_process(self) -> None:
        """Quota exhaustion fails every text and reports a quota error for each."""
        model_patch = patch(
            self._EMBEDDING_MODEL,
            new_callable=AsyncMock,
            return_value="text-embedding-ada-002",
        )

        with patch(self._LLM_CLIENT) as get_client, model_patch:
            quota_error = openai.RateLimitError(
                "insufficient_quota: Quota exceeded", response=Mock(), body=None
            )
            get_client.return_value = self._client_with_create_effect(quota_error)
            outcome = await create_embeddings_batch(["text1", "text2", "text3", "text4"])

        # Nothing succeeds; all four texts are recorded as quota failures.
        assert outcome.success_count == 0
        assert outcome.failure_count == 4
        assert len(outcome.embeddings) == 0
        assert all("quota" in item["error"].lower() for item in outcome.failed_items)

    @pytest.mark.asyncio
    async def test_no_zero_vectors_in_results(self) -> None:
        """A failing batch call reports a failure instead of a ``[0.0] * 1536`` vector."""
        # Meta-test: the implementation must never fabricate zero vectors.

        def is_zero_embedding(value):
            """Return True only for a 1536-element list made entirely of 0.0."""
            return (
                isinstance(value, list)
                and len(value) == 1536
                and all(component == 0.0 for component in value)
            )

        # Input that should never produce a zero embedding.
        sample_text = "This is a test"

        # The batch function must answer an error with a failure result.
        with patch(self._LLM_CLIENT) as get_client:
            get_client.return_value = self._client_with_create_effect(Exception("Test error"))
            outcome = await create_embeddings_batch([sample_text])

        assert isinstance(outcome, EmbeddingBatchResult)
        assert len(outcome.embeddings) == 0
        assert outcome.failure_count == 1
        # NOTE(review): with the list asserted empty above this loop body never
        # runs; it is kept as a guard against zero vectors.
        for vector in outcome.embeddings:
            assert not is_zero_embedding(vector)
|
|
|
|
|
|
class TestEmbeddingBatchResult:
    """Exercise the bookkeeping of ``EmbeddingBatchResult`` directly."""

    def test_batch_result_initialization(self) -> None:
        """A new result has zero counts, empty lists and no failures."""
        fresh = EmbeddingBatchResult()

        assert fresh.success_count == 0
        assert fresh.failure_count == 0
        assert fresh.embeddings == []
        assert fresh.failed_items == []
        assert not fresh.has_failures

    def test_batch_result_add_success(self) -> None:
        """``add_success`` stores the vector and its text and bumps only the success count."""
        vector = [0.1] * 1536
        source_text = "test text"
        batch = EmbeddingBatchResult()

        batch.add_success(vector, source_text)

        assert batch.success_count == 1
        assert batch.failure_count == 0
        assert len(batch.embeddings) == 1
        assert batch.embeddings[0] == vector
        assert batch.texts_processed[0] == source_text
        assert not batch.has_failures

    def test_batch_result_add_failure(self) -> None:
        """``add_failure`` records the error details and flags the result as failed."""
        batch = EmbeddingBatchResult()
        api_error = EmbeddingAPIError("Test error", text_preview="test")

        batch.add_failure("test text", api_error, batch_index=0)

        assert batch.success_count == 0
        assert batch.failure_count == 1
        assert len(batch.failed_items) == 1
        assert batch.has_failures

        recorded = batch.failed_items[0]
        assert recorded["error"] == "Test error"
        assert recorded["error_type"] == "EmbeddingAPIError"
        # batch_index comes from the error's to_dict() output; only its
        # presence is checked here.
        assert "batch_index" in recorded

    def test_batch_result_mixed_results(self) -> None:
        """Two successes plus two failures are counted separately and total four."""
        batch = EmbeddingBatchResult()

        # Two successful embeddings ...
        for vector, text in (([0.1] * 1536, "text1"), ([0.2] * 1536, "text2")):
            batch.add_success(vector, text)

        # ... followed by two failures attributed to batch index 1.
        for text, message in (("text3", "Error 1"), ("text4", "Error 2")):
            batch.add_failure(text, Exception(message), 1)

        assert batch.success_count == 2
        assert batch.failure_count == 2
        assert batch.total_requested == 4
        assert batch.has_failures
        assert len(batch.embeddings) == 2
        assert len(batch.failed_items) == 2
|