""" Tests for embedding service to ensure no zero embeddings are returned. These tests verify that the embedding service raises appropriate exceptions instead of returning zero embeddings, following the "fail fast and loud" principle. """ from unittest.mock import AsyncMock, Mock, patch import openai import pytest from src.server.services.embeddings.embedding_exceptions import ( EmbeddingAPIError, EmbeddingQuotaExhaustedError, EmbeddingRateLimitError, ) from src.server.services.embeddings.embedding_service import ( EmbeddingBatchResult, create_embedding, create_embeddings_batch, ) class TestNoZeroEmbeddings: """Test that no zero embeddings are ever returned.""" # Note: Removed test_sync_from_async_context_raises_exception # as sync versions no longer exist - everything is async-only now @pytest.mark.asyncio async def test_async_quota_exhausted_returns_failure(self) -> None: """Test that quota exhaustion returns failure result instead of zeros.""" with patch( "src.server.services.embeddings.embedding_service.get_llm_client" ) as mock_client: # Mock the client to raise quota error mock_ctx = AsyncMock() mock_ctx.__aenter__.return_value.embeddings.create.side_effect = openai.RateLimitError( "insufficient_quota: You have exceeded your quota", response=Mock(), body=None ) mock_client.return_value = mock_ctx # Single embedding still raises for backward compatibility with pytest.raises(EmbeddingQuotaExhaustedError) as exc_info: await create_embedding("test text") assert "quota exhausted" in str(exc_info.value).lower() @pytest.mark.asyncio async def test_async_rate_limit_raises_exception(self) -> None: """Test that rate limit errors raise exception after retries.""" with patch( "src.server.services.embeddings.embedding_service.get_llm_client" ) as mock_client: # Mock the client to raise rate limit error mock_ctx = AsyncMock() mock_ctx.__aenter__.return_value.embeddings.create.side_effect = openai.RateLimitError( "rate_limit_exceeded: Too many requests", response=Mock(), body=None ) mock_client.return_value = mock_ctx with pytest.raises(EmbeddingRateLimitError) as exc_info: await create_embedding("test text") assert "rate limit" in str(exc_info.value).lower() @pytest.mark.asyncio async def test_async_api_error_raises_exception(self) -> None: """Test that API errors raise exception instead of returning zeros.""" with patch( "src.server.services.embeddings.embedding_service.get_llm_client" ) as mock_client: # Mock the client to raise generic error mock_ctx = AsyncMock() mock_ctx.__aenter__.return_value.embeddings.create.side_effect = Exception( "Network error" ) mock_client.return_value = mock_ctx with pytest.raises(EmbeddingAPIError) as exc_info: await create_embedding("test text") assert "failed to create embedding" in str(exc_info.value).lower() @pytest.mark.asyncio async def test_batch_handles_partial_failures(self) -> None: """Test that batch processing can handle partial failures gracefully.""" with patch( "src.server.services.embeddings.embedding_service.get_llm_client" ) as mock_client: # Mock successful response for first batch, failure for second mock_ctx = AsyncMock() mock_response = Mock() mock_response.data = [Mock(embedding=[0.1] * 1536), Mock(embedding=[0.2] * 1536)] # First call succeeds, second fails mock_ctx.__aenter__.return_value.embeddings.create.side_effect = [ mock_response, Exception("API Error"), ] mock_client.return_value = mock_ctx with patch( "src.server.services.embeddings.embedding_service.get_embedding_model", new_callable=AsyncMock, return_value="text-embedding-ada-002", ): # Mock credential service to return batch size of 2 with patch( "src.server.services.embeddings.embedding_service.credential_service.get_credentials_by_category", new_callable=AsyncMock, return_value={"EMBEDDING_BATCH_SIZE": "2"}, ): # Process 4 texts (batch size will be 2) texts = ["text1", "text2", "text3", "text4"] result = await create_embeddings_batch(texts) # Check result structure assert isinstance(result, EmbeddingBatchResult) assert result.success_count == 2 # First batch succeeded assert result.failure_count == 2 # Second batch failed assert len(result.embeddings) == 2 assert len(result.failed_items) == 2 # Verify no zero embeddings were created for embedding in result.embeddings: assert not all(v == 0.0 for v in embedding) @pytest.mark.asyncio async def test_configurable_embedding_dimensions(self) -> None: """Test that embedding dimensions can be configured via settings.""" with patch( "src.server.services.embeddings.embedding_service.get_llm_client" ) as mock_client: # Mock successful response mock_ctx = AsyncMock() mock_create = AsyncMock() mock_ctx.__aenter__.return_value.embeddings.create = mock_create # Setup mock response mock_response = Mock() mock_response.data = [Mock(embedding=[0.1] * 3072)] # Different dimensions mock_create.return_value = mock_response mock_client.return_value = mock_ctx with patch( "src.server.services.embeddings.embedding_service.get_embedding_model", new_callable=AsyncMock, return_value="text-embedding-3-large", ): # Mock credential service to return custom dimensions with patch( "src.server.services.embeddings.embedding_service.credential_service.get_credentials_by_category", new_callable=AsyncMock, return_value={"EMBEDDING_DIMENSIONS": "3072"}, ): result = await create_embeddings_batch(["test text"]) # Verify the dimensions parameter was passed correctly mock_create.assert_called_once() call_args = mock_create.call_args assert call_args.kwargs["dimensions"] == 3072 # Verify result assert result.success_count == 1 assert len(result.embeddings[0]) == 3072 @pytest.mark.asyncio async def test_default_embedding_dimensions(self) -> None: """Test that default dimensions (1536) are used when not configured.""" with patch( "src.server.services.embeddings.embedding_service.get_llm_client" ) as mock_client: # Mock successful response mock_ctx = AsyncMock() mock_create = AsyncMock() mock_ctx.__aenter__.return_value.embeddings.create = mock_create # Setup mock response with default dimensions mock_response = Mock() mock_response.data = [Mock(embedding=[0.1] * 1536)] mock_create.return_value = mock_response mock_client.return_value = mock_ctx with patch( "src.server.services.embeddings.embedding_service.get_embedding_model", new_callable=AsyncMock, return_value="text-embedding-3-small", ): # Mock credential service to return empty settings (no dimensions specified) with patch( "src.server.services.embeddings.embedding_service.credential_service.get_credentials_by_category", new_callable=AsyncMock, return_value={}, ): result = await create_embeddings_batch(["test text"]) # Verify the default dimensions parameter was used mock_create.assert_called_once() call_args = mock_create.call_args assert call_args.kwargs["dimensions"] == 1536 # Verify result assert result.success_count == 1 assert len(result.embeddings[0]) == 1536 @pytest.mark.asyncio async def test_batch_quota_exhausted_stops_process(self) -> None: """Test that quota exhaustion stops processing remaining batches.""" with patch( "src.server.services.embeddings.embedding_service.get_llm_client" ) as mock_client: # Mock quota exhaustion mock_ctx = AsyncMock() mock_ctx.__aenter__.return_value.embeddings.create.side_effect = openai.RateLimitError( "insufficient_quota: Quota exceeded", response=Mock(), body=None ) mock_client.return_value = mock_ctx with patch( "src.server.services.embeddings.embedding_service.get_embedding_model", new_callable=AsyncMock, return_value="text-embedding-ada-002", ): texts = ["text1", "text2", "text3", "text4"] result = await create_embeddings_batch(texts) # All should fail due to quota assert result.success_count == 0 assert result.failure_count == 4 assert len(result.embeddings) == 0 assert all("quota" in item["error"].lower() for item in result.failed_items) @pytest.mark.asyncio async def test_no_zero_vectors_in_results(self) -> None: """Test that no function ever returns a zero vector [0.0] * 1536.""" # This is a meta-test to ensure our implementation never creates zero vectors # Helper to check if a value is a zero embedding def is_zero_embedding(value): if not isinstance(value, list): return False if len(value) != 1536: return False return all(v == 0.0 for v in value) # Test data that should never produce zero embeddings test_text = "This is a test" # Test: Batch function with error should return failure result, not zeros with patch( "src.server.services.embeddings.embedding_service.get_llm_client" ) as mock_client: # Mock the client to raise an error mock_ctx = AsyncMock() mock_ctx.__aenter__.return_value.embeddings.create.side_effect = Exception("Test error") mock_client.return_value = mock_ctx result = await create_embeddings_batch([test_text]) # Should return result with failures, not zeros assert isinstance(result, EmbeddingBatchResult) assert len(result.embeddings) == 0 assert result.failure_count == 1 # Verify no zero embeddings in the result for embedding in result.embeddings: assert not is_zero_embedding(embedding) class TestEmbeddingBatchResult: """Test the EmbeddingBatchResult dataclass.""" def test_batch_result_initialization(self) -> None: """Test that EmbeddingBatchResult initializes correctly.""" result = EmbeddingBatchResult() assert result.success_count == 0 assert result.failure_count == 0 assert result.embeddings == [] assert result.failed_items == [] assert not result.has_failures def test_batch_result_add_success(self) -> None: """Test adding successful embeddings.""" result = EmbeddingBatchResult() embedding = [0.1] * 1536 text = "test text" result.add_success(embedding, text) assert result.success_count == 1 assert result.failure_count == 0 assert len(result.embeddings) == 1 assert result.embeddings[0] == embedding assert result.texts_processed[0] == text assert not result.has_failures def test_batch_result_add_failure(self) -> None: """Test adding failed items.""" result = EmbeddingBatchResult() error = EmbeddingAPIError("Test error", text_preview="test") result.add_failure("test text", error, batch_index=0) assert result.success_count == 0 assert result.failure_count == 1 assert len(result.failed_items) == 1 assert result.has_failures failed_item = result.failed_items[0] assert failed_item["error"] == "Test error" assert failed_item["error_type"] == "EmbeddingAPIError" # batch_index comes from the error's to_dict() method which includes it assert "batch_index" in failed_item # Just check it exists def test_batch_result_mixed_results(self) -> None: """Test batch result with both successes and failures.""" result = EmbeddingBatchResult() # Add successes result.add_success([0.1] * 1536, "text1") result.add_success([0.2] * 1536, "text2") # Add failures result.add_failure("text3", Exception("Error 1"), 1) result.add_failure("text4", Exception("Error 2"), 1) assert result.success_count == 2 assert result.failure_count == 2 assert result.total_requested == 4 assert result.has_failures assert len(result.embeddings) == 2 assert len(result.failed_items) == 2