"""
Comprehensive Tests for Async Background Task Manager

Tests the pure async background task manager after removal of ThreadPoolExecutor.
Focuses on async task execution, concurrency control, and progress tracking.
"""

import asyncio
from typing import Any
from unittest.mock import AsyncMock

import pytest

from src.server.services.background_task_manager import (
    BackgroundTaskManager,
    cleanup_task_manager,
    get_task_manager,
)


class TestAsyncBackgroundTaskManager:
    """Test suite for async background task manager"""

    @pytest.fixture
    def task_manager(self):
        """Create a fresh task manager instance for each test"""
        return BackgroundTaskManager(max_concurrent_tasks=5)

    @pytest.fixture
    def mock_progress_callback(self):
        """Mock progress callback function"""
        return AsyncMock()

    @pytest.mark.asyncio
    async def test_task_manager_initialization(self, task_manager):
        """Test task manager initialization"""
        assert task_manager.max_concurrent_tasks == 5
        assert len(task_manager.active_tasks) == 0
        assert len(task_manager.task_metadata) == 0
        # White-box check: the semaphore's internal counter starts at the limit.
        assert task_manager._task_semaphore._value == 5

    @pytest.mark.asyncio
    async def test_simple_async_task_execution(self, task_manager, mock_progress_callback):
        """Test execution of a simple async task"""

        async def simple_task(message: str):
            await asyncio.sleep(0.01)  # Simulate async work
            return f"Task completed: {message}"

        task_id = await task_manager.submit_task(
            simple_task, ("Hello World",), progress_callback=mock_progress_callback
        )

        # Wait for task completion
        await asyncio.sleep(0.05)

        # Check task status
        status = await task_manager.get_task_status(task_id)
        assert status["status"] == "complete"
        assert status["progress"] == 100
        assert status["result"] == "Task completed: Hello World"

        # Verify progress callback was called
        assert mock_progress_callback.call_count >= 1

    @pytest.mark.asyncio
    async def test_task_with_error(self, task_manager, mock_progress_callback):
        """Test handling of task that raises an exception"""

        async def failing_task():
            await asyncio.sleep(0.01)
            raise ValueError("Task failed intentionally")

        task_id = await task_manager.submit_task(
            failing_task, (), progress_callback=mock_progress_callback
        )

        # Wait for task to fail
        await asyncio.sleep(0.05)

        # Check task status
        status = await task_manager.get_task_status(task_id)
        assert status["status"] == "error"
        assert status["progress"] == -1
        assert "error" in status
        assert "Task failed intentionally" in status["error"]

        # Verify error was reported via progress callback: pick the first call whose
        # second positional argument is an update dict with status "error".
        error_call = next(
            (
                call
                for call in mock_progress_callback.call_args_list
                if len(call.args) >= 2 and call.args[1].get("status") == "error"
            ),
            None,
        )
        assert error_call is not None
        assert "Task failed intentionally" in error_call.args[1]["error"]

    @pytest.mark.asyncio
    async def test_concurrent_task_execution(self, task_manager):
        """Test execution of multiple concurrent tasks"""

        async def numbered_task(number: int):
            await asyncio.sleep(0.01)
            return f"Task {number} completed"

        # Submit 5 tasks simultaneously
        task_ids = []
        for i in range(5):
            task_id = await task_manager.submit_task(numbered_task, (i,), task_id=f"task-{i}")
            task_ids.append(task_id)

        # Wait for all tasks to complete
        await asyncio.sleep(0.05)

        # Check all tasks completed successfully
        for i, task_id in enumerate(task_ids):
            status = await task_manager.get_task_status(task_id)
            assert status["status"] == "complete"
            assert status["result"] == f"Task {i} completed"

    @pytest.mark.asyncio
    async def test_concurrency_limit(self):
        """Test that concurrency is limited by semaphore"""
        # Use a dedicated task manager with a limit of 2 (the shared fixture allows 5)
        limited_manager = BackgroundTaskManager(max_concurrent_tasks=2)
        running_tasks = []
        completed_tasks = []

        async def long_running_task(task_id: int):
            running_tasks.append(task_id)
            await asyncio.sleep(0.05)  # Long enough to test concurrency
            completed_tasks.append(task_id)
            return f"Task {task_id} completed"

        try:
            # Submit 4 tasks
            for i in range(4):
                await limited_manager.submit_task(
                    long_running_task, (i,), task_id=f"concurrent-task-{i}"
                )

            # Wait a bit and check that at most 2 tasks have started
            await asyncio.sleep(0.01)
            assert len(running_tasks) <= 2

            # Wait for all to complete
            await asyncio.sleep(0.3)
            assert len(completed_tasks) == 4
        finally:
            # Always release the manager, even if an assertion above failed,
            # so leftover tasks cannot leak into later tests.
            await limited_manager.cleanup()

    @pytest.mark.asyncio
    async def test_task_cancellation(self, task_manager):
        """Test cancellation of running task"""

        async def long_task():
            # Long enough to be cancelled; CancelledError propagates naturally.
            await asyncio.sleep(1.0)
            return "Should not complete"

        task_id = await task_manager.submit_task(long_task, (), task_id="cancellable-task")

        # Wait a bit, then cancel
        await asyncio.sleep(0.01)
        cancelled = await task_manager.cancel_task(task_id)
        assert cancelled is True

        # Check task status
        await asyncio.sleep(0.01)
        status = await task_manager.get_task_status(task_id)
        assert status["status"] == "cancelled"

    @pytest.mark.asyncio
    async def test_task_not_found(self, task_manager):
        """Test getting status of non-existent task"""
        status = await task_manager.get_task_status("non-existent-task")
        assert status["error"] == "Task not found"

    @pytest.mark.asyncio
    async def test_cancel_non_existent_task(self, task_manager):
        """Test cancelling non-existent task"""
        cancelled = await task_manager.cancel_task("non-existent-task")
        assert cancelled is False

    @pytest.mark.asyncio
    async def test_progress_callback_execution(self, task_manager):
        """Test that progress callback is properly executed"""
        progress_updates = []

        async def mock_progress_callback(task_id: str, update: dict[str, Any]):
            progress_updates.append((task_id, update))

        async def simple_task():
            await asyncio.sleep(0.01)
            return "completed"

        task_id = await task_manager.submit_task(
            simple_task, (), task_id="progress-test-task", progress_callback=mock_progress_callback
        )

        # Wait for completion
        await asyncio.sleep(0.05)

        # Should have at least one progress update (completion)
        assert len(progress_updates) >= 1

        # Check that task_id matches
        assert all(update[0] == task_id for update in progress_updates)

        # Check for completion update
        completion_updates = [
            update for update in progress_updates if update[1].get("status") == "complete"
        ]
        assert len(completion_updates) >= 1
        assert completion_updates[0][1]["percentage"] == 100

    @pytest.mark.asyncio
    async def test_progress_callback_error_handling(self, task_manager):
        """Test that task continues even if progress callback fails"""

        async def failing_progress_callback(task_id: str, update: dict[str, Any]):
            raise Exception("Progress callback failed")

        async def simple_task():
            await asyncio.sleep(0.01)
            return "Task completed despite callback failure"

        task_id = await task_manager.submit_task(
            simple_task, (), progress_callback=failing_progress_callback
        )

        # Wait for completion
        await asyncio.sleep(0.05)

        # Task should still complete successfully
        status = await task_manager.get_task_status(task_id)
        assert status["status"] == "complete"
        assert status["result"] == "Task completed despite callback failure"

    @pytest.mark.asyncio
    async def test_task_metadata_tracking(self, task_manager):
        """Test that task metadata is properly tracked"""

        async def simple_task():
            await asyncio.sleep(0.01)
            return "result"

        task_id = await task_manager.submit_task(simple_task, (), task_id="metadata-test")

        # Check initial metadata
        initial_status = await task_manager.get_task_status(task_id)
        assert initial_status["status"] == "running"
        assert "created_at" in initial_status
        assert initial_status["progress"] == 0

        # Wait for completion
        await asyncio.sleep(0.05)

        # Check final metadata
        final_status = await task_manager.get_task_status(task_id)
        assert final_status["status"] == "complete"
        assert final_status["progress"] == 100
        assert final_status["result"] == "result"

    @pytest.mark.asyncio
    async def test_cleanup_active_tasks(self, task_manager):
        """Test cleanup cancels active tasks"""

        async def long_running_task():
            # Sleeps long enough that cleanup() must cancel it.
            await asyncio.sleep(1.0)
            return "Should not complete"

        # Submit multiple long-running tasks
        for i in range(3):
            await task_manager.submit_task(long_running_task, (), task_id=f"cleanup-test-{i}")

        # Verify tasks are active
        await asyncio.sleep(0.01)
        assert len(task_manager.active_tasks) == 3

        # Cleanup
        await task_manager.cleanup()

        # Verify all tasks were cancelled and cleaned up
        assert len(task_manager.active_tasks) == 0
        assert len(task_manager.task_metadata) == 0

    @pytest.mark.asyncio
    async def test_completed_task_status_after_removal(self, task_manager):
        """Test getting status of completed task after it's removed from active_tasks"""

        async def quick_task():
            return "quick result"

        task_id = await task_manager.submit_task(quick_task, (), task_id="quick-test")

        # Wait for completion and removal from active_tasks
        await asyncio.sleep(0.05)

        # Should still be able to get status from metadata
        status = await task_manager.get_task_status(task_id)
        assert status["status"] == "complete"
        assert status["result"] == "quick result"

    def test_set_main_loop_deprecated(self, task_manager):
        """Test that set_main_loop is deprecated but doesn't break"""
        # Should not raise an exception but may log a warning
        loop = asyncio.new_event_loop()
        try:
            task_manager.set_main_loop(loop)
        finally:
            # Close the loop even if set_main_loop raises, to avoid a resource leak.
            loop.close()


class TestGlobalTaskManager:
    """Test the global task manager functions"""

    def test_get_task_manager_singleton(self):
        """Test that get_task_manager returns singleton"""
        manager1 = get_task_manager()
        manager2 = get_task_manager()
        assert manager1 is manager2

    @pytest.mark.asyncio
    async def test_cleanup_task_manager(self):
        """Test cleanup of global task manager"""
        # Get the global manager
        manager = get_task_manager()
        assert manager is not None

        # Add a task to make it interesting
        async def trivial_task():
            return "test"

        await manager.submit_task(trivial_task, ())
        await asyncio.sleep(0.01)

        # Cleanup
        await cleanup_task_manager()

        # Verify it was cleaned up - getting a new one should be different
        new_manager = get_task_manager()
        assert new_manager is not manager


class TestAsyncTaskPatterns:
    """Test various async task patterns and edge cases"""

    @pytest.fixture
    def task_manager(self):
        """Create a fresh task manager (limit 3) for each test"""
        return BackgroundTaskManager(max_concurrent_tasks=3)

    @pytest.mark.asyncio
    async def test_nested_async_calls(self, task_manager):
        """Test tasks that make nested async calls"""

        async def nested_task():
            async def inner_task():
                await asyncio.sleep(0.01)
                return "inner result"

            result = await inner_task()
            return f"outer: {result}"

        task_id = await task_manager.submit_task(nested_task, ())
        await asyncio.sleep(0.05)

        status = await task_manager.get_task_status(task_id)
        assert status["status"] == "complete"
        assert status["result"] == "outer: inner result"

    @pytest.mark.asyncio
    async def test_task_with_async_context_manager(self, task_manager):
        """Test tasks that use async context managers"""

        class AsyncResource:
            def __init__(self):
                self.entered = False
                self.exited = False

            async def __aenter__(self):
                await asyncio.sleep(0.001)
                self.entered = True
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                await asyncio.sleep(0.001)
                self.exited = True

        resource = AsyncResource()

        async def context_manager_task():
            async with resource:
                await asyncio.sleep(0.01)
                return "context manager used"

        task_id = await task_manager.submit_task(context_manager_task, ())
        await asyncio.sleep(0.05)

        status = await task_manager.get_task_status(task_id)
        assert status["status"] == "complete"
        assert status["result"] == "context manager used"
        assert resource.entered
        assert resource.exited

    @pytest.mark.asyncio
    async def test_task_cancellation_propagation(self, task_manager):
        """Test that cancellation properly propagates through nested calls"""
        cancelled_flags = []

        async def cancellable_inner():
            try:
                await asyncio.sleep(1.0)
                return "should not complete"
            except asyncio.CancelledError:
                cancelled_flags.append("inner")
                raise

        async def cancellable_outer():
            try:
                result = await cancellable_inner()
                return f"outer: {result}"
            except asyncio.CancelledError:
                cancelled_flags.append("outer")
                raise

        task_id = await task_manager.submit_task(cancellable_outer, ())
        await asyncio.sleep(0.01)

        # Cancel the task
        cancelled = await task_manager.cancel_task(task_id)
        assert cancelled

        await asyncio.sleep(0.01)

        # Both inner and outer should have been cancelled
        assert "inner" in cancelled_flags
        assert "outer" in cancelled_flags

    @pytest.mark.asyncio
    async def test_high_concurrency_stress_test(self, task_manager):
        """Stress test with many concurrent tasks"""

        async def stress_task(task_num: int):
            await asyncio.sleep(0.001 * (task_num % 10))  # Vary sleep time
            return f"stress-{task_num}"

        # Submit many tasks
        task_ids = []
        num_tasks = 20

        for i in range(num_tasks):
            task_id = await task_manager.submit_task(stress_task, (i,), task_id=f"stress-{i}")
            task_ids.append(task_id)

        # Wait for all to complete
        await asyncio.sleep(0.5)

        # Verify all completed successfully
        for i, task_id in enumerate(task_ids):
            status = await task_manager.get_task_status(task_id)
            assert status["status"] == "complete"
            assert status["result"] == f"stress-{i}"

    @pytest.mark.asyncio
    async def test_task_execution_order_with_semaphore(self):
        """Test that semaphore properly controls execution order"""
        # Use a dedicated manager with a limit of 2 (the shared fixture allows 3)
        limited_manager = BackgroundTaskManager(max_concurrent_tasks=2)
        execution_order = []

        async def ordered_task(task_id: int):
            execution_order.append(f"start-{task_id}")
            await asyncio.sleep(0.02)
            execution_order.append(f"end-{task_id}")
            return task_id

        try:
            # Submit 4 tasks
            for i in range(4):
                await limited_manager.submit_task(ordered_task, (i,), task_id=f"order-{i}")

            # Wait for completion
            await asyncio.sleep(0.2)

            # Every task must have both started and finished; otherwise the
            # concurrency check below would pass vacuously.
            assert len(execution_order) == 8

            # Replay the event log, tracking how many tasks are in flight at once.
            # It must never exceed the semaphore limit.
            in_flight = 0
            for event in execution_order:
                in_flight += 1 if event.startswith("start-") else -1
                assert in_flight <= 2
        finally:
            # Always release the manager, even if an assertion above failed.
            await limited_manager.cleanup()