diff --git a/python/src/server/services/embeddings/embedding_service.py b/python/src/server/services/embeddings/embedding_service.py
index 47b5520..7051678 100644
--- a/python/src/server/services/embeddings/embedding_service.py
+++ b/python/src/server/services/embeddings/embedding_service.py
@@ -205,8 +205,25 @@ async def create_embeddings_batch(
         batch_tokens = sum(len(text.split()) for text in batch) * 1.3
         total_tokens_used += batch_tokens
 
+        # Create rate limit progress callback if we have a progress callback
+        rate_limit_callback = None
+        if progress_callback or websocket:
+            async def rate_limit_callback(data: dict):
+                # Send heartbeat during rate limit wait
+                if progress_callback:
+                    processed = result.success_count + result.failure_count
+                    message = f"Rate limited: {data.get('message', 'Waiting...')}"
+                    await progress_callback(message, (processed / len(texts)) * 100)
+
+                if websocket:
+                    await websocket.send_json({
+                        "type": "rate_limit_wait",
+                        "message": data.get("message", "Rate limited, waiting..."),
+                        "remaining_seconds": data.get("remaining_seconds", 0)
+                    })
+
         # Rate limit each batch
-        async with threading_service.rate_limited_operation(batch_tokens):
+        async with threading_service.rate_limited_operation(batch_tokens, rate_limit_callback):
             retry_count = 0
             max_retries = 3
diff --git a/python/src/server/services/storage/document_storage_service.py b/python/src/server/services/storage/document_storage_service.py
index 340870e..24a0732 100644
--- a/python/src/server/services/storage/document_storage_service.py
+++ b/python/src/server/services/storage/document_storage_service.py
@@ -233,9 +233,23 @@ async def add_documents_to_supabase(
             # If not using contextual embeddings, use original contents
             contextual_contents = batch_contents
 
-        # Create embeddings for the batch - no progress reporting
-        # Don't pass websocket to avoid Socket.IO issues
-        result = await create_embeddings_batch(contextual_contents, provider=provider)
+        # Create embeddings for the batch with rate limit progress support
+        # Create a wrapper for progress callback to handle rate limiting updates
+        async def embedding_progress_wrapper(message: str, percentage: float):
+            # Forward rate limiting messages to the main progress callback
+            if progress_callback and "rate limit" in message.lower():
+                await progress_callback(
+                    message,
+                    current_percentage,  # Use current batch progress
+                    {"batch": batch_num, "type": "rate_limit_wait"}
+                )
+
+        # Pass progress callback for rate limiting updates
+        result = await create_embeddings_batch(
+            contextual_contents,
+            provider=provider,
+            progress_callback=embedding_progress_wrapper if progress_callback else None
+        )
 
         # Log any failures
         if result.has_failures:
diff --git a/python/src/server/services/threading_service.py b/python/src/server/services/threading_service.py
index b3a0053..6e26581 100644
--- a/python/src/server/services/threading_service.py
+++ b/python/src/server/services/threading_service.py
@@ -84,16 +84,30 @@ class RateLimiter:
         self.semaphore = asyncio.Semaphore(config.max_concurrent)
         self._lock = asyncio.Lock()
 
-    async def acquire(self, estimated_tokens: int = 8000) -> bool:
-        """Acquire permission to make API call with token awareness"""
-        async with self._lock:
-            now = time.time()
+    async def acquire(self, estimated_tokens: int = 8000, progress_callback: Callable | None = None) -> bool:
+        """Acquire permission to make API call with token awareness
+
+        Args:
+            estimated_tokens: Estimated number of tokens for the operation
+            progress_callback: Optional async callback for progress updates during wait
+        """
+        while True:  # Loop instead of recursion to avoid stack overflow
+            wait_time_to_sleep = None
+
+            async with self._lock:
+                now = time.time()
 
-            # Clean old entries
-            self._clean_old_entries(now)
+                # Clean old entries
+                self._clean_old_entries(now)
 
-            # Check if we can make the request
-            if not self._can_make_request(estimated_tokens):
+                # Check if we can make the request
+                if self._can_make_request(estimated_tokens):
+                    # Record the request
+                    self.request_times.append(now)
+                    self.token_usage.append((now, estimated_tokens))
+                    return True
+
+                # Calculate wait time if we can't make the request
                 wait_time = self._calculate_wait_time(estimated_tokens)
                 if wait_time > 0:
                     logfire_logger.info(
@@ -103,14 +117,30 @@ class RateLimiter:
                             "current_usage": self._get_current_usage(),
                         }
                     )
-                    await asyncio.sleep(wait_time)
-                    return await self.acquire(estimated_tokens)
-                return False
-
-            # Record the request
-            self.request_times.append(now)
-            self.token_usage.append((now, estimated_tokens))
-            return True
+                    wait_time_to_sleep = wait_time
+                else:
+                    return False
+
+            # Sleep outside the lock to avoid deadlock
+            if wait_time_to_sleep is not None:
+                # For long waits, break into smaller chunks with progress updates
+                if wait_time_to_sleep > 5 and progress_callback:
+                    chunks = int(wait_time_to_sleep / 5)  # 5 second chunks
+                    for i in range(chunks):
+                        await asyncio.sleep(5)
+                        remaining = wait_time_to_sleep - (i + 1) * 5
+                        if progress_callback:
+                            await progress_callback({
+                                "type": "rate_limit_wait",
+                                "remaining_seconds": max(0, remaining),
+                                "message": f"waiting {max(0, remaining):.1f}s more..."
+                            })
+                    # Sleep any remaining time
+                    if wait_time_to_sleep % 5 > 0:
+                        await asyncio.sleep(wait_time_to_sleep % 5)
+                else:
+                    await asyncio.sleep(wait_time_to_sleep)
+                # Continue the loop to try again
 
     def _can_make_request(self, estimated_tokens: int) -> bool:
         """Check if request can be made within limits"""
@@ -510,10 +540,15 @@ class ThreadingService:
         logfire_logger.info("Threading service stopped")
 
     @asynccontextmanager
-    async def rate_limited_operation(self, estimated_tokens: int = 8000):
-        """Context manager for rate-limited operations"""
+    async def rate_limited_operation(self, estimated_tokens: int = 8000, progress_callback: Callable | None = None):
+        """Context manager for rate-limited operations
+
+        Args:
+            estimated_tokens: Estimated number of tokens for the operation
+            progress_callback: Optional async callback for progress updates during wait
+        """
         async with self.rate_limiter.semaphore:
-            can_proceed = await self.rate_limiter.acquire(estimated_tokens)
+            can_proceed = await self.rate_limiter.acquire(estimated_tokens, progress_callback)
             if not can_proceed:
                 raise Exception("Rate limit exceeded")
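
For reference, a minimal sketch of how a caller could consume the new hook. The `rate_limited_operation` signature and the callback payload keys match the diff above; `report` and `embed_with_feedback` are hypothetical names introduced only for illustration, and the token estimate reuses the same word-count * 1.3 heuristic as embedding_service.py.

    # Hypothetical consumer of the new progress hook (sketch, not part of this PR).
    async def report(data: dict) -> None:
        # Payload keys ("message", "remaining_seconds") match what
        # RateLimiter.acquire emits during chunked waits.
        print(f"[rate limit] {data.get('message', '')} "
              f"(~{data.get('remaining_seconds', 0):.1f}s remaining)")

    async def embed_with_feedback(threading_service, texts: list[str]) -> None:
        # Same rough token estimate used in create_embeddings_batch
        estimated_tokens = int(sum(len(t.split()) for t in texts) * 1.3)
        async with threading_service.rate_limited_operation(estimated_tokens, report):
            ...  # call the embedding API here

Note the design choice in acquire(): the sleep happens outside self._lock, so other coroutines can re-check the limiter while one waiter is parked, and the while loop replaces the old tail recursion so long waits cannot grow the stack.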