Archon/python/src/server/services/provider_discovery_service.py
Commit d4e80a945a by John Fitzpatrick, 2025-09-20 13:36:33 -07:00
fix: Change Ollama default URL to host.docker.internal for Docker compatibility

- Changed default Ollama URL from localhost:11434 to host.docker.internal:11434
- This allows Docker containers to connect to Ollama running on the host machine
- Updated in backend services, frontend components, migration scripts, and documentation
- Most users run Archon in Docker but Ollama as a local binary, making this a better default

"""
Provider Discovery Service
Discovers available models, checks provider health, and provides model specifications
for OpenAI, Google Gemini, Ollama, and Anthropic providers.
"""
import time
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlparse
import aiohttp
import openai
from ..config.logfire_config import get_logger
from .credential_service import credential_service
logger = get_logger(__name__)
# Provider capabilities and model specifications cache
_provider_cache: dict[str, tuple[Any, float]] = {}
_CACHE_TTL_SECONDS = 300 # 5 minutes
# Default Ollama instance URL (configurable via environment/settings)
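# host.docker.internal resolves to the host machine from inside a Docker
# container, covering the common setup where Archon runs in Docker while
# Ollama runs as a local binary on the host.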
DEFAULT_OLLAMA_URL = "http://host.docker.internal:11434"
# Model pattern detection for dynamic capabilities (no hardcoded model names)
CHAT_MODEL_PATTERNS = ["llama", "qwen", "mistral", "codellama", "phi", "gemma", "vicuna", "orca"]
EMBEDDING_MODEL_PATTERNS = ["embed", "embedding"]
VISION_MODEL_PATTERNS = ["vision", "llava", "moondream"]
# Context window estimates by model family (heuristics, not hardcoded requirements)
MODEL_CONTEXT_WINDOWS = {
"llama3": 8192,
"qwen": 32768,
"mistral": 8192,
"codellama": 16384,
"phi": 4096,
"gemma": 8192,
}
# Embedding dimensions for common models (heuristics)
EMBEDDING_DIMENSIONS = {
"nomic-embed": 768,
"mxbai-embed": 1024,
"all-minilm": 384,
}
@dataclass
class ModelSpec:
"""Model specification with capabilities and constraints."""
name: str
provider: str
context_window: int
supports_tools: bool = False
supports_vision: bool = False
supports_embeddings: bool = False
embedding_dimensions: int | None = None
pricing_input: float | None = None # Per million tokens
pricing_output: float | None = None # Per million tokens
description: str = ""
    aliases: list[str] | None = None
def __post_init__(self):
if self.aliases is None:
self.aliases = []
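# Illustrative example (hypothetical values, not a canonical spec):
#   ModelSpec(name="llama3:8b", provider="ollama", context_window=8192,
#             supports_tools=True, aliases=["llama3"])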
@dataclass
class ProviderStatus:
"""Provider health and connectivity status."""
provider: str
is_available: bool
response_time_ms: float | None = None
error_message: str | None = None
models_available: int = 0
base_url: str | None = None
last_checked: float | None = None
class ProviderDiscoveryService:
"""Service for discovering models and checking provider health."""
def __init__(self):
self._session: aiohttp.ClientSession | None = None
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create HTTP session for provider requests."""
if self._session is None:
timeout = aiohttp.ClientTimeout(total=30, connect=10)
self._session = aiohttp.ClientSession(timeout=timeout)
return self._session
async def close(self):
"""Close HTTP session."""
if self._session:
await self._session.close()
self._session = None
def _get_cached_result(self, cache_key: str) -> Any | None:
"""Get cached result if not expired."""
if cache_key in _provider_cache:
result, timestamp = _provider_cache[cache_key]
if time.time() - timestamp < _CACHE_TTL_SECONDS:
return result
else:
del _provider_cache[cache_key]
return None
def _cache_result(self, cache_key: str, result: Any) -> None:
"""Cache result with current timestamp."""
_provider_cache[cache_key] = (result, time.time())
async def _test_tool_support(self, model_name: str, api_url: str) -> bool:
"""
Test if a model supports function/tool calling by making an actual API call.
Args:
model_name: Name of the model to test
api_url: Base URL of the Ollama instance
Returns:
True if tool calling is supported, False otherwise
"""
        client = None
        try:
            # Reuse the module-level openai import for an OpenAI-compatible
            # function-calling probe; Ollama ignores the API key, but the
            # client requires a non-empty value.
            client = openai.AsyncOpenAI(
                base_url=f"{api_url}/v1",
                api_key="ollama",
            )
# Define a simple test function
test_function = {
"name": "test_function",
"description": "A test function",
"parameters": {
"type": "object",
"properties": {
"test_param": {
"type": "string",
"description": "A test parameter"
}
},
"required": ["test_param"]
}
}
# Try to make a function calling request
response = await client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": "Call the test function with parameter 'hello'"}],
tools=[{"type": "function", "function": test_function}],
max_tokens=50,
timeout=5 # Short timeout for quick testing
)
# Check if the model attempted to use the function
if response.choices and len(response.choices) > 0:
choice = response.choices[0]
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
logger.info(f"Model {model_name} supports tool calling")
return True
return False
except Exception as e:
logger.debug(f"Tool support test failed for {model_name}: {e}")
# Fall back to name-based heuristics for known models
return any(pattern in model_name.lower()
for pattern in CHAT_MODEL_PATTERNS)
        finally:
            if client is not None:
                await client.close()
async def discover_openai_models(self, api_key: str) -> list[ModelSpec]:
"""Discover available OpenAI models."""
cache_key = f"openai_models_{hash(api_key)}"
cached = self._get_cached_result(cache_key)
if cached:
return cached
models = []
try:
client = openai.AsyncOpenAI(api_key=api_key)
response = await client.models.list()
# OpenAI model specifications
model_specs = {
"gpt-4o": ModelSpec("gpt-4o", "openai", 128000, True, True, False, None, 2.50, 10.00, "Most capable GPT-4 model with vision"),
"gpt-4o-mini": ModelSpec("gpt-4o-mini", "openai", 128000, True, True, False, None, 0.15, 0.60, "Affordable GPT-4 model"),
"gpt-4-turbo": ModelSpec("gpt-4-turbo", "openai", 128000, True, True, False, None, 10.00, 30.00, "GPT-4 Turbo with vision"),
"gpt-3.5-turbo": ModelSpec("gpt-3.5-turbo", "openai", 16385, True, False, False, None, 0.50, 1.50, "Fast and efficient model"),
"text-embedding-3-large": ModelSpec("text-embedding-3-large", "openai", 8191, False, False, True, 3072, 0.13, 0, "High-quality embedding model"),
"text-embedding-3-small": ModelSpec("text-embedding-3-small", "openai", 8191, False, False, True, 1536, 0.02, 0, "Efficient embedding model"),
"text-embedding-ada-002": ModelSpec("text-embedding-ada-002", "openai", 8191, False, False, True, 1536, 0.10, 0, "Legacy embedding model"),
}
for model in response.data:
if model.id in model_specs:
models.append(model_specs[model.id])
else:
# Create basic spec for unknown models
models.append(ModelSpec(
name=model.id,
provider="openai",
context_window=4096, # Default assumption
description=f"OpenAI model {model.id}"
))
self._cache_result(cache_key, models)
logger.info(f"Discovered {len(models)} OpenAI models")
except Exception as e:
logger.error(f"Error discovering OpenAI models: {e}")
return models
async def discover_google_models(self, api_key: str) -> list[ModelSpec]:
"""Discover available Google Gemini models."""
cache_key = f"google_models_{hash(api_key)}"
cached = self._get_cached_result(cache_key)
if cached:
return cached
models = []
try:
# Google Gemini model specifications
model_specs = [
ModelSpec("gemini-1.5-pro", "google", 2097152, True, True, False, None, 1.25, 5.00, "Advanced reasoning and multimodal capabilities"),
ModelSpec("gemini-1.5-flash", "google", 1048576, True, True, False, None, 0.075, 0.30, "Fast and versatile performance"),
ModelSpec("gemini-1.0-pro", "google", 30720, True, False, False, None, 0.50, 1.50, "Efficient model for text tasks"),
ModelSpec("text-embedding-004", "google", 2048, False, False, True, 768, 0.00, 0, "Google's latest embedding model"),
]
            # Test connectivity with a simple request; the Gemini REST API
            # authenticates via the `key` query parameter rather than a
            # Bearer token, so no Authorization header is needed.
            session = await self._get_session()
            base_url = "https://generativelanguage.googleapis.com/v1beta/models"
            async with session.get(f"{base_url}?key={api_key}") as response:
if response.status == 200:
models = model_specs
self._cache_result(cache_key, models)
logger.info(f"Discovered {len(models)} Google models")
else:
logger.warning(f"Google API returned status {response.status}")
except Exception as e:
logger.error(f"Error discovering Google models: {e}")
return models
async def discover_ollama_models(self, base_urls: list[str]) -> list[ModelSpec]:
"""Discover available Ollama models from multiple instances."""
all_models = []
for base_url in base_urls:
cache_key = f"ollama_models_{base_url}"
cached = self._get_cached_result(cache_key)
if cached:
all_models.extend(cached)
continue
try:
                # Clean up URL: strip a trailing /v1 suffix (the OpenAI-compatible
                # endpoint) to reach the raw Ollama API. Slice rather than use
                # str.replace, which would also mangle a /v1 elsewhere in the URL.
                parsed = urlparse(base_url)
                if parsed.path.endswith('/v1'):
                    api_url = base_url[: -len('/v1')]
                else:
                    api_url = base_url
session = await self._get_session()
# Get installed models
async with session.get(f"{api_url}/api/tags") as response:
if response.status == 200:
data = await response.json()
models = []
for model_info in data.get("models", []):
model_name = model_info.get("name", "").split(':')[0] # Remove tag
# Determine model capabilities based on testing and name patterns
# Test for function calling capabilities via actual API calls
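                            # Note: this probe issues one short chat request per
                            # model, so discovery latency grows with the number of
                            # installed models (results are cached for 5 minutes).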
supports_tools = await self._test_tool_support(model_name, api_url)
# Vision support is typically indicated by name patterns (reliable indicator)
supports_vision = any(pattern in model_name.lower() for pattern in VISION_MODEL_PATTERNS)
# Embedding support is typically indicated by name patterns (reliable indicator)
supports_embeddings = any(pattern in model_name.lower() for pattern in EMBEDDING_MODEL_PATTERNS)
# Estimate context window based on model family
context_window = 4096 # Default
for family, window_size in MODEL_CONTEXT_WINDOWS.items():
if family in model_name.lower():
context_window = window_size
break
# Set embedding dimensions for known embedding models
embedding_dims = None
for model_pattern, dims in EMBEDDING_DIMENSIONS.items():
if model_pattern in model_name.lower():
embedding_dims = dims
break
spec = ModelSpec(
name=model_info.get("name", model_name),
provider="ollama",
context_window=context_window,
supports_tools=supports_tools,
supports_vision=supports_vision,
supports_embeddings=supports_embeddings,
embedding_dimensions=embedding_dims,
description=f"Ollama model on {base_url}",
aliases=[model_name] if ':' in model_info.get("name", "") else []
)
models.append(spec)
self._cache_result(cache_key, models)
all_models.extend(models)
logger.info(f"Discovered {len(models)} Ollama models from {base_url}")
else:
logger.warning(f"Ollama instance at {base_url} returned status {response.status}")
except Exception as e:
logger.error(f"Error discovering Ollama models from {base_url}: {e}")
return all_models
async def discover_anthropic_models(self, api_key: str) -> list[ModelSpec]:
"""Discover available Anthropic Claude models."""
cache_key = f"anthropic_models_{hash(api_key)}"
cached = self._get_cached_result(cache_key)
if cached:
return cached
models = []
try:
# Anthropic Claude model specifications
model_specs = [
ModelSpec("claude-3-5-sonnet-20241022", "anthropic", 200000, True, True, False, None, 3.00, 15.00, "Most intelligent Claude model"),
ModelSpec("claude-3-5-haiku-20241022", "anthropic", 200000, True, False, False, None, 0.25, 1.25, "Fast and cost-effective Claude model"),
ModelSpec("claude-3-opus-20240229", "anthropic", 200000, True, True, False, None, 15.00, 75.00, "Powerful model for complex tasks"),
ModelSpec("claude-3-sonnet-20240229", "anthropic", 200000, True, True, False, None, 3.00, 15.00, "Balanced performance and cost"),
ModelSpec("claude-3-haiku-20240307", "anthropic", 200000, True, False, False, None, 0.25, 1.25, "Fast responses and cost-effective"),
]
# Test connectivity - Anthropic doesn't have a models list endpoint,
# so we'll just return the known models if API key is provided
if api_key:
models = model_specs
self._cache_result(cache_key, models)
logger.info(f"Discovered {len(models)} Anthropic models")
except Exception as e:
logger.error(f"Error discovering Anthropic models: {e}")
return models
async def check_provider_health(self, provider: str, config: dict[str, Any]) -> ProviderStatus:
"""Check health and connectivity status of a provider."""
start_time = time.time()
try:
if provider == "openai":
api_key = config.get("api_key")
if not api_key:
return ProviderStatus(provider, False, None, "API key not configured")
client = openai.AsyncOpenAI(api_key=api_key)
models = await client.models.list()
response_time = (time.time() - start_time) * 1000
return ProviderStatus(
provider="openai",
is_available=True,
response_time_ms=response_time,
models_available=len(models.data),
last_checked=time.time()
)
elif provider == "google":
api_key = config.get("api_key")
if not api_key:
return ProviderStatus(provider, False, None, "API key not configured")
session = await self._get_session()
base_url = "https://generativelanguage.googleapis.com/v1beta/models"
async with session.get(f"{base_url}?key={api_key}") as response:
response_time = (time.time() - start_time) * 1000
if response.status == 200:
data = await response.json()
return ProviderStatus(
provider="google",
is_available=True,
response_time_ms=response_time,
models_available=len(data.get("models", [])),
base_url=base_url,
last_checked=time.time()
)
else:
return ProviderStatus(provider, False, response_time, f"HTTP {response.status}")
elif provider == "ollama":
base_urls = config.get("base_urls", [config.get("base_url", DEFAULT_OLLAMA_URL)])
if isinstance(base_urls, str):
base_urls = [base_urls]
# Check the first available Ollama instance
for base_url in base_urls:
try:
                        # Strip a trailing /v1 suffix to reach the raw Ollama API
                        # (slice instead of str.replace, which would also hit a
                        # /v1 elsewhere in the URL).
                        parsed = urlparse(base_url)
                        if parsed.path.endswith('/v1'):
                            api_url = base_url[: -len('/v1')]
                        else:
                            api_url = base_url
session = await self._get_session()
async with session.get(f"{api_url}/api/tags") as response:
response_time = (time.time() - start_time) * 1000
if response.status == 200:
data = await response.json()
return ProviderStatus(
provider="ollama",
is_available=True,
response_time_ms=response_time,
models_available=len(data.get("models", [])),
base_url=api_url,
last_checked=time.time()
)
except Exception:
continue # Try next URL
return ProviderStatus(provider, False, None, "No Ollama instances available")
elif provider == "anthropic":
api_key = config.get("api_key")
if not api_key:
return ProviderStatus(provider, False, None, "API key not configured")
# Anthropic doesn't have a health check endpoint, so we'll assume it's available
# if API key is provided. In a real implementation, you might want to make a
# small test request to verify the key is valid.
response_time = (time.time() - start_time) * 1000
return ProviderStatus(
provider="anthropic",
is_available=True,
response_time_ms=response_time,
models_available=5, # Known model count
last_checked=time.time()
)
else:
return ProviderStatus(provider, False, None, f"Unknown provider: {provider}")
except Exception as e:
response_time = (time.time() - start_time) * 1000
return ProviderStatus(
provider=provider,
is_available=False,
response_time_ms=response_time,
error_message=str(e),
last_checked=time.time()
)
async def get_all_available_models(self) -> dict[str, list[ModelSpec]]:
"""Get all available models from all configured providers."""
providers = {}
try:
# Get provider configurations
rag_settings = await credential_service.get_credentials_by_category("rag_strategy")
# OpenAI
openai_key = await credential_service.get_credential("OPENAI_API_KEY")
if openai_key:
providers["openai"] = await self.discover_openai_models(openai_key)
# Google
google_key = await credential_service.get_credential("GOOGLE_API_KEY")
if google_key:
providers["google"] = await self.discover_google_models(google_key)
# Ollama
ollama_urls = [rag_settings.get("LLM_BASE_URL", DEFAULT_OLLAMA_URL)]
providers["ollama"] = await self.discover_ollama_models(ollama_urls)
# Anthropic
anthropic_key = await credential_service.get_credential("ANTHROPIC_API_KEY")
if anthropic_key:
providers["anthropic"] = await self.discover_anthropic_models(anthropic_key)
except Exception as e:
logger.error(f"Error getting all available models: {e}")
return providers
# Global instance
provider_discovery_service = ProviderDiscoveryService()
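# Minimal smoke-test sketch (illustrative only, not part of the service API).
# It assumes an Ollama instance is reachable at DEFAULT_OLLAMA_URL and that the
# module is imported in its package context (it uses relative imports).
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        service = ProviderDiscoveryService()
        try:
            status = await service.check_provider_health(
                "ollama", {"base_url": DEFAULT_OLLAMA_URL}
            )
            print(
                f"ollama available={status.is_available} "
                f"models={status.models_available} error={status.error_message}"
            )
            if status.is_available:
                for spec in await service.discover_ollama_models([DEFAULT_OLLAMA_URL]):
                    print(f"- {spec.name}: ctx={spec.context_window}, tools={spec.supports_tools}")
        finally:
            await service.close()

    asyncio.run(_demo())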