Feat: OpenRouter/Anthropic/Grok support (#231)

* Add Anthropic and Grok provider support

* feat: Add crucial GPT-5 and reasoning model support for OpenRouter

- Add requires_max_completion_tokens() function for GPT-5, o1, o3, Grok-3 series
- Add prepare_chat_completion_params() for reasoning model compatibility
- Implement max_tokens → max_completion_tokens conversion for reasoning models
- Add temperature handling for reasoning models (they only accept the default temperature of 1.0; see the sketch after this list)
- Enhanced provider validation and API key security in provider endpoints
- Streamlined retry logic (3→2 attempts) for faster issue detection
- Add failure tracking and circuit breaker analysis for debugging
- Support OpenRouter format detection (openai/gpt-5-nano, openai/o1-mini)
- Improved Grok provider empty response handling with structured fallbacks
- Enhanced contextual embedding with provider-aware model selection
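
As a rough illustration of the two helpers named above, here is a minimal sketch of how requires_max_completion_tokens() and prepare_chat_completion_params() could behave. The function names and the max_tokens → max_completion_tokens / temperature rules come from this commit; the prefix list and the internals are assumptions, not the shipped implementation.

```python
# Minimal sketch, not the shipped code: the prefix list and matching logic are assumptions.
REASONING_MODEL_PREFIXES = ("gpt-5", "o1", "o3", "grok-3")


def requires_max_completion_tokens(model: str) -> bool:
    """True for reasoning models that expect max_completion_tokens instead of max_tokens."""
    # Accept OpenRouter-prefixed names such as "openai/gpt-5-nano" or "openai/o1-mini".
    bare = model.split("/", 1)[-1].lower()
    return bare.startswith(REASONING_MODEL_PREFIXES)


def prepare_chat_completion_params(model: str, params: dict) -> dict:
    """Adapt standard chat-completion params for reasoning models."""
    prepared = dict(params)
    if requires_max_completion_tokens(model):
        if "max_tokens" in prepared:
            prepared["max_completion_tokens"] = prepared.pop("max_tokens")
        # Reasoning models currently only accept the default temperature of 1.0.
        prepared["temperature"] = 1.0
    return prepared
```

The contextual_embedding_service.py diff further down shows the intended call pattern: build a plain params dict, run it through prepare_chat_completion_params(model, params), and unpack the result into client.chat.completions.create().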

Core provider functionality:
- OpenRouter, Grok, Anthropic provider support with full embedding integration
- Provider-specific model defaults and validation
- Secure API connectivity testing endpoints
- Provider context passing for code generation workflows

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fully working model providers; addressed security and code-related concerns and thoroughly hardened the code

* Added multi-provider support and embedding model support, and cleaned up the PR; still need to fix the health check, asyncio task errors, and the contextual embeddings error

* Fixed the contextual embeddings issue

* Added inspect-aware shutdown handling so get_llm_client always closes the underlying AsyncOpenAI / httpx.AsyncClient while the loop is still alive, with defensive logging if shutdown happens late (python/src/server/services/llm_provider_service.py:14, python/src/server/services/llm_provider_service.py:520).
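
A hedged sketch of the inspect-aware close path described here; the real helper lives in llm_provider_service.py and may differ, but the idea is to await whichever aclose/close the SDK exposes and only log if shutdown arrives after the loop is gone.

```python
import inspect
import logging

logger = logging.getLogger(__name__)


async def _close_llm_client(client) -> None:
    """Close an AsyncOpenAI / httpx.AsyncClient, whether it exposes aclose() or close()."""
    closer = getattr(client, "aclose", None) or getattr(client, "close", None)
    if closer is None:
        return
    try:
        result = closer()
        if inspect.isawaitable(result):  # some SDKs return a coroutine, some close synchronously
            await result
    except RuntimeError as exc:
        # Defensive logging when shutdown happens after the event loop has stopped.
        logger.warning("LLM client close skipped during late shutdown: %s", exc)
```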

* Restructured get_llm_client so client creation and usage live in separate try/finally blocks; fallback clients now close without logging a spurious "Error creating LLM client" when downstream code raises (python/src/server/services/llm_provider_service.py:335-556).
- Close logic now sanitizes provider names consistently and awaits whichever aclose/close coroutine the SDK exposes, keeping the loop shut down cleanly (python/src/server/services/llm_provider_service.py:530-559).

Robust JSON parsing:
- Added _extract_json_payload to strip code fences / extra text returned by Ollama before json.loads runs, averting the markdown-induced decode errors seen in the logs (python/src/server/services/storage/code_storage_service.py:40-63).
- Swapped the direct parse call for the sanitized payload and emit a debug preview when cleanup alters the content (python/src/server/services/storage/code_storage_service.py:858-864).
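
The _extract_json_payload helper is described here but not shown; a self-contained sketch of the fence-stripping behaviour it is said to implement (the details are assumed, not copied from the shipped code) could look like this:

```python
import json
import re


def _extract_json_payload(raw: str) -> str:
    """Strip Markdown code fences / surrounding prose so json.loads sees plain JSON."""
    text = raw.strip()
    # Prefer the body of a ```json ... ``` (or bare ```) fence if the model added one.
    fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
    if fenced:
        text = fenced.group(1).strip()
    # Otherwise fall back to the outermost {...} or [...] span in the response.
    if not text.startswith(("{", "[")):
        match = re.search(r"\{.*\}|\[.*\]", text, re.DOTALL)
        if match:
            text = match.group(0)
    return text


# Usage: data = json.loads(_extract_json_payload(model_response))
```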

* added provider connection support

* added provider api key not being configured warning

* Updated get_llm_client so missing OpenAI keys automatically fall back to Ollama (matching existing tests) and so unsupported providers still raise the legacy ValueError the suite expects. The fallback now reuses _get_optimal_ollama_instance and rethrows ValueError("OpenAI API key not found and Ollama fallback failed") when it can't connect. Adjusted test_code_extraction_source_id.py to accept the new optional argument on the mocked extractor (and confirm it's None when present).
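
Sketched below is the fallback flow this entry describes; the _get_optimal_ollama_instance name and the error message are taken from the text above, everything else is illustrative and not the shipped code.

```python
async def _fallback_to_ollama_if_no_openai_key(api_key: str | None):
    """If no OpenAI key is configured, try Ollama; otherwise keep the legacy error behaviour."""
    if api_key:
        return "openai", api_key
    try:
        # Reuse the existing helper (assumed to live in llm_provider_service.py)
        # that picks the healthiest configured Ollama instance.
        instance = await _get_optimal_ollama_instance()
        return "ollama", instance
    except Exception as exc:
        raise ValueError("OpenAI API key not found and Ollama fallback failed") from exc
```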

* Resolved several CodeRabbit suggestions:
- Updated the knowledge API key validation to call create_embedding with the provider argument and removed the hard-coded OpenAI fallback (python/src/server/api_routes/knowledge_api.py).
- Broadened embedding provider detection so prefixed OpenRouter/OpenAI model names route through the correct client (python/src/server/services/embeddings/embedding_service.py, python/src/server/services/llm_provider_service.py).
- Removed the duplicate helper definitions from llm_provider_service.py, eliminating the stray docstring that was causing the import-time syntax error.
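
For the "prefixed model names route through the correct client" point, a rough sketch of the kind of detection involved; the real helper is is_openai_embedding_model in llm_provider_service.py, and the model set and prefix handling here are assumptions.

```python
OPENAI_EMBEDDING_MODELS = {
    "text-embedding-ada-002",
    "text-embedding-3-small",
    "text-embedding-3-large",
}


def is_openai_embedding_model(model: str) -> bool:
    """True for OpenAI embedding models, with or without an OpenRouter-style 'openai/' prefix."""
    bare = model.lower().removeprefix("openai/")
    return bare in OPENAI_EMBEDDING_MODELS or bare.startswith("text-embedding-")
```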

* Updated via CodeRabbit PR review; CodeRabbit in my IDE found no issues and no nitpicks with the updates. What was done:
- The credential service now persists the provider under the uppercase key LLM_PROVIDER, matching the read path (no new EMBEDDING_PROVIDER usage introduced).
- Embedding batch creation stops inserting blank strings, logging failures and skipping invalid items before they ever hit the provider (python/src/server/services/embeddings/embedding_service.py).
- Contextual embedding prompts use real newline characters everywhere, both when constructing the batch prompt and when parsing the model's response (python/src/server/services/embeddings/contextual_embedding_service.py).
- Embedding provider routing already recognizes OpenRouter-prefixed OpenAI models via is_openai_embedding_model; no further change was needed there.
- Embedding insertion now skips unsupported vector dimensions instead of forcing them into the 1536-dimension column, and the backoff loop uses await asyncio.sleep so we no longer block the event loop (python/src/server/services/storage/code_storage_service.py).
- RAG settings props were extended to include LLM_INSTANCE_NAME and OLLAMA_EMBEDDING_INSTANCE_NAME, and the debug log no longer prints API-key prefixes (the rest of the TanStack refactor/EMBEDDING_PROVIDER support remains deferred).
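
Two of the bullets above (skipping blank inputs, non-blocking backoff) boil down to a pattern like the following sketch; the function name and retry counts are illustrative, not the shipped code.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def embed_batch(client, model: str, texts: list[str], max_retries: int = 3) -> list[list[float]]:
    # Skip blank or non-string items before they ever reach the provider.
    valid = [t for t in texts if isinstance(t, str) and t.strip()]
    if len(valid) != len(texts):
        logger.warning("Skipping %d blank/invalid items in embedding batch", len(texts) - len(valid))
    if not valid:
        return []
    for attempt in range(max_retries):
        try:
            response = await client.embeddings.create(model=model, input=valid)
            return [item.embedding for item in response.data]
        except Exception:
            if attempt == max_retries - 1:
                raise
            await asyncio.sleep(2 ** attempt)  # back off without blocking the event loop
    return []
```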

* test fix

* Enhanced OpenRouter's parsing logic to automatically detect reasoning models and parse responses whether or not they are JSON. This commit makes Archon's parsing work thoroughly with OpenRouter, regardless of the model you're using, ensuring proper functionality without breaking any generation capabilities.
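
A self-contained sketch of the "parse regardless of JSON output or not" idea; the reasoning-field fallback mirrors the extract_message_text helper used in the diffs below, but the exact field names and behaviour here are assumptions.

```python
import json


def parse_model_reply(choice) -> tuple[str, dict | list | None]:
    """Return (text, parsed_json_or_None) for a chat completion choice from OpenRouter."""
    message = getattr(choice, "message", None)
    text = (getattr(message, "content", None) or "").strip()
    if not text:
        # Some reasoning models leave content empty and put usable text in a reasoning field.
        text = (getattr(message, "reasoning", None) or "").strip()
    candidate = text
    if candidate.startswith("```"):
        # Drop a Markdown fence such as ```json ... ``` before attempting to parse.
        candidate = candidate.strip("`").removeprefix("json").strip()
    try:
        return text, json.loads(candidate)
    except ValueError:
        # Not JSON: hand the plain text through so generation keeps working.
        return text, None
```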

---------

Co-authored-by: Chillbruhhh <joshchesser97@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Josh 2025-09-22 02:36:30 -05:00 committed by GitHub
parent 4c910c1471
commit 394ac1befa
17 changed files with 2090 additions and 450 deletions

View File

@ -9,6 +9,103 @@ import { credentialsService } from '../../services/credentialsService';
import OllamaModelDiscoveryModal from './OllamaModelDiscoveryModal';
import OllamaModelSelectionModal from './OllamaModelSelectionModal';
type ProviderKey = 'openai' | 'google' | 'ollama' | 'anthropic' | 'grok' | 'openrouter';
interface ProviderModels {
chatModel: string;
embeddingModel: string;
}
type ProviderModelMap = Record<ProviderKey, ProviderModels>;
// Provider model persistence helpers
const PROVIDER_MODELS_KEY = 'archon_provider_models';
const getDefaultModels = (provider: ProviderKey): ProviderModels => {
const chatDefaults: Record<ProviderKey, string> = {
openai: 'gpt-4o-mini',
anthropic: 'claude-3-5-sonnet-20241022',
google: 'gemini-1.5-flash',
grok: 'grok-3-mini', // Updated to use grok-3-mini as default
openrouter: 'openai/gpt-4o-mini',
ollama: 'llama3:8b'
};
const embeddingDefaults: Record<ProviderKey, string> = {
openai: 'text-embedding-3-small',
anthropic: 'text-embedding-3-small', // Fallback to OpenAI
google: 'text-embedding-004',
grok: 'text-embedding-3-small', // Fallback to OpenAI
openrouter: 'text-embedding-3-small',
ollama: 'nomic-embed-text'
};
return {
chatModel: chatDefaults[provider],
embeddingModel: embeddingDefaults[provider]
};
};
const saveProviderModels = (providerModels: ProviderModelMap): void => {
try {
localStorage.setItem(PROVIDER_MODELS_KEY, JSON.stringify(providerModels));
} catch (error) {
console.error('Failed to save provider models:', error);
}
};
const loadProviderModels = (): ProviderModelMap => {
try {
const saved = localStorage.getItem(PROVIDER_MODELS_KEY);
if (saved) {
return JSON.parse(saved);
}
} catch (error) {
console.error('Failed to load provider models:', error);
}
// Return defaults for all providers if nothing saved
const providers: ProviderKey[] = ['openai', 'google', 'openrouter', 'ollama', 'anthropic', 'grok'];
const defaultModels: ProviderModelMap = {} as ProviderModelMap;
providers.forEach(provider => {
defaultModels[provider] = getDefaultModels(provider);
});
return defaultModels;
};
// Static color styles mapping (prevents Tailwind JIT purging)
const colorStyles: Record<ProviderKey, string> = {
openai: 'border-green-500 bg-green-500/10',
google: 'border-blue-500 bg-blue-500/10',
openrouter: 'border-cyan-500 bg-cyan-500/10',
ollama: 'border-purple-500 bg-purple-500/10',
anthropic: 'border-orange-500 bg-orange-500/10',
grok: 'border-yellow-500 bg-yellow-500/10',
};
const providerAlertStyles: Record<ProviderKey, string> = {
openai: 'bg-green-50 dark:bg-green-900/20 border-green-200 dark:border-green-800 text-green-800 dark:text-green-300',
google: 'bg-blue-50 dark:bg-blue-900/20 border-blue-200 dark:border-blue-800 text-blue-800 dark:text-blue-300',
openrouter: 'bg-cyan-50 dark:bg-cyan-900/20 border-cyan-200 dark:border-cyan-800 text-cyan-800 dark:text-cyan-300',
ollama: 'bg-purple-50 dark:bg-purple-900/20 border-purple-200 dark:border-purple-800 text-purple-800 dark:text-purple-300',
anthropic: 'bg-orange-50 dark:bg-orange-900/20 border-orange-200 dark:border-orange-800 text-orange-800 dark:text-orange-300',
grok: 'bg-yellow-50 dark:bg-yellow-900/20 border-yellow-200 dark:border-yellow-800 text-yellow-800 dark:text-yellow-300',
};
const providerAlertMessages: Record<ProviderKey, string> = {
openai: 'Configure your OpenAI API key in the credentials section to use GPT models.',
google: 'Configure your Google API key in the credentials section to use Gemini models.',
openrouter: 'Configure your OpenRouter API key in the credentials section to use models.',
ollama: 'Configure your Ollama instances in this panel to connect local models.',
anthropic: 'Configure your Anthropic API key in the credentials section to use Claude models.',
grok: 'Configure your Grok API key in the credentials section to use Grok models.',
};
const isProviderKey = (value: unknown): value is ProviderKey =>
typeof value === 'string' && ['openai', 'google', 'openrouter', 'ollama', 'anthropic', 'grok'].includes(value);
interface RAGSettingsProps {
ragSettings: {
MODEL_CHOICE: string;
@ -19,8 +116,10 @@ interface RAGSettingsProps {
USE_RERANKING: boolean;
LLM_PROVIDER?: string;
LLM_BASE_URL?: string;
LLM_INSTANCE_NAME?: string;
EMBEDDING_MODEL?: string;
OLLAMA_EMBEDDING_URL?: string;
OLLAMA_EMBEDDING_INSTANCE_NAME?: string;
// Crawling Performance Settings
CRAWL_BATCH_SIZE?: number;
CRAWL_MAX_CONCURRENT?: number;
@ -57,7 +156,10 @@ export const RAGSettings = ({
// Model selection modals state
const [showLLMModelSelectionModal, setShowLLMModelSelectionModal] = useState(false);
const [showEmbeddingModelSelectionModal, setShowEmbeddingModelSelectionModal] = useState(false);
// Provider-specific model persistence state
const [providerModels, setProviderModels] = useState<ProviderModelMap>(() => loadProviderModels());
// Instance configurations
const [llmInstanceConfig, setLLMInstanceConfig] = useState({
name: '',
@ -113,6 +215,25 @@ export const RAGSettings = ({
}
}, [ragSettings.OLLAMA_EMBEDDING_URL, ragSettings.OLLAMA_EMBEDDING_INSTANCE_NAME]);
// Provider model persistence effects
useEffect(() => {
// Update provider models when current models change
const currentProvider = ragSettings.LLM_PROVIDER as ProviderKey;
if (currentProvider && ragSettings.MODEL_CHOICE && ragSettings.EMBEDDING_MODEL) {
setProviderModels(prev => {
const updated = {
...prev,
[currentProvider]: {
chatModel: ragSettings.MODEL_CHOICE,
embeddingModel: ragSettings.EMBEDDING_MODEL
}
};
saveProviderModels(updated);
return updated;
});
}
}, [ragSettings.MODEL_CHOICE, ragSettings.EMBEDDING_MODEL, ragSettings.LLM_PROVIDER]);
// Load API credentials for status checking
useEffect(() => {
const loadApiCredentials = async () => {
@ -197,58 +318,27 @@ export const RAGSettings = ({
}>({});
// Test connection to external providers
const testProviderConnection = async (provider: string, apiKey: string): Promise<boolean> => {
const testProviderConnection = async (provider: string): Promise<boolean> => {
setProviderConnectionStatus(prev => ({
...prev,
[provider]: { ...prev[provider], checking: true }
}));
try {
switch (provider) {
case 'openai':
// Test OpenAI connection with a simple completion request
const openaiResponse = await fetch('https://api.openai.com/v1/models', {
method: 'GET',
headers: {
'Authorization': `Bearer ${apiKey}`,
'Content-Type': 'application/json'
}
});
if (openaiResponse.ok) {
setProviderConnectionStatus(prev => ({
...prev,
openai: { connected: true, checking: false, lastChecked: new Date() }
}));
return true;
} else {
throw new Error(`OpenAI API returned ${openaiResponse.status}`);
}
// Use server-side API endpoint for secure connectivity testing
const response = await fetch(`/api/providers/${provider}/status`);
const result = await response.json();
case 'google':
// Test Google Gemini connection
const googleResponse = await fetch(`https://generativelanguage.googleapis.com/v1/models?key=${apiKey}`, {
method: 'GET',
headers: {
'Content-Type': 'application/json'
}
});
if (googleResponse.ok) {
setProviderConnectionStatus(prev => ({
...prev,
google: { connected: true, checking: false, lastChecked: new Date() }
}));
return true;
} else {
throw new Error(`Google API returned ${googleResponse.status}`);
}
const isConnected = result.ok && result.reason === 'connected';
default:
return false;
}
setProviderConnectionStatus(prev => ({
...prev,
[provider]: { connected: isConnected, checking: false, lastChecked: new Date() }
}));
return isConnected;
} catch (error) {
console.error(`Failed to test ${provider} connection:`, error);
console.error(`Error testing ${provider} connection:`, error);
setProviderConnectionStatus(prev => ({
...prev,
[provider]: { connected: false, checking: false, lastChecked: new Date() }
@ -260,37 +350,27 @@ export const RAGSettings = ({
// Test provider connections when API credentials change
useEffect(() => {
const testConnections = async () => {
const providers = ['openai', 'google'];
// Test all supported providers
const providers = ['openai', 'google', 'anthropic', 'openrouter', 'grok'];
for (const provider of providers) {
const keyName = provider === 'openai' ? 'OPENAI_API_KEY' : 'GOOGLE_API_KEY';
const apiKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === keyName);
const keyValue = apiKey ? apiCredentials[apiKey] : undefined;
if (keyValue && keyValue.trim().length > 0) {
// Don't test if we've already checked recently (within last 30 seconds)
const lastChecked = providerConnectionStatus[provider]?.lastChecked;
const now = new Date();
const timeSinceLastCheck = lastChecked ? now.getTime() - lastChecked.getTime() : Infinity;
if (timeSinceLastCheck > 30000) { // 30 seconds
console.log(`🔄 Testing ${provider} connection...`);
await testProviderConnection(provider, keyValue);
}
} else {
// No API key, mark as disconnected
setProviderConnectionStatus(prev => ({
...prev,
[provider]: { connected: false, checking: false, lastChecked: new Date() }
}));
// Don't test if we've already checked recently (within last 30 seconds)
const lastChecked = providerConnectionStatus[provider]?.lastChecked;
const now = new Date();
const timeSinceLastCheck = lastChecked ? now.getTime() - lastChecked.getTime() : Infinity;
if (timeSinceLastCheck > 30000) { // 30 seconds
console.log(`🔄 Testing ${provider} connection...`);
await testProviderConnection(provider);
}
}
};
// Only test if we have credentials loaded
if (Object.keys(apiCredentials).length > 0) {
testConnections();
}
// Test connections periodically (every 60 seconds)
testConnections();
const interval = setInterval(testConnections, 60000);
return () => clearInterval(interval);
}, [apiCredentials]); // Test when credentials change
// Ref to track if initial test has been run (will be used after function definitions)
@ -662,24 +742,41 @@ export const RAGSettings = ({
if (llmStatus.online || embeddingStatus.online) return 'partial';
return 'missing';
case 'anthropic':
// Check if Anthropic API key is configured (case insensitive)
const anthropicKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'ANTHROPIC_API_KEY');
const hasAnthropicKey = anthropicKey && apiCredentials[anthropicKey] && apiCredentials[anthropicKey].trim().length > 0;
return hasAnthropicKey ? 'configured' : 'missing';
// Use server-side connection status
const anthropicConnected = providerConnectionStatus['anthropic']?.connected || false;
const anthropicChecking = providerConnectionStatus['anthropic']?.checking || false;
if (anthropicChecking) return 'partial';
return anthropicConnected ? 'configured' : 'missing';
case 'grok':
// Check if Grok API key is configured (case insensitive)
const grokKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'GROK_API_KEY');
const hasGrokKey = grokKey && apiCredentials[grokKey] && apiCredentials[grokKey].trim().length > 0;
return hasGrokKey ? 'configured' : 'missing';
// Use server-side connection status
const grokConnected = providerConnectionStatus['grok']?.connected || false;
const grokChecking = providerConnectionStatus['grok']?.checking || false;
if (grokChecking) return 'partial';
return grokConnected ? 'configured' : 'missing';
case 'openrouter':
// Check if OpenRouter API key is configured (case insensitive)
const openRouterKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'OPENROUTER_API_KEY');
const hasOpenRouterKey = openRouterKey && apiCredentials[openRouterKey] && apiCredentials[openRouterKey].trim().length > 0;
return hasOpenRouterKey ? 'configured' : 'missing';
// Use server-side connection status
const openRouterConnected = providerConnectionStatus['openrouter']?.connected || false;
const openRouterChecking = providerConnectionStatus['openrouter']?.checking || false;
if (openRouterChecking) return 'partial';
return openRouterConnected ? 'configured' : 'missing';
default:
return 'missing';
}
};;
};
const selectedProviderKey = isProviderKey(ragSettings.LLM_PROVIDER)
? (ragSettings.LLM_PROVIDER as ProviderKey)
: undefined;
const selectedProviderStatus = selectedProviderKey ? getProviderStatus(selectedProviderKey) : undefined;
const shouldShowProviderAlert = Boolean(
selectedProviderKey && selectedProviderStatus === 'missing'
);
const providerAlertClassName = shouldShowProviderAlert && selectedProviderKey
? providerAlertStyles[selectedProviderKey]
: '';
const providerAlertMessage = shouldShowProviderAlert && selectedProviderKey
? providerAlertMessages[selectedProviderKey]
: '';
// Test Ollama connectivity when Settings page loads (scenario 4: page load)
// This useEffect is placed after function definitions to ensure access to manualTestConnection
@ -750,55 +847,32 @@ export const RAGSettings = ({
{[
{ key: 'openai', name: 'OpenAI', logo: '/img/OpenAI.png', color: 'green' },
{ key: 'google', name: 'Google', logo: '/img/google-logo.svg', color: 'blue' },
{ key: 'openrouter', name: 'OpenRouter', logo: '/img/OpenRouter.png', color: 'cyan' },
{ key: 'ollama', name: 'Ollama', logo: '/img/Ollama.png', color: 'purple' },
{ key: 'anthropic', name: 'Anthropic', logo: '/img/claude-logo.svg', color: 'orange' },
{ key: 'grok', name: 'Grok', logo: '/img/Grok.png', color: 'yellow' },
{ key: 'openrouter', name: 'OpenRouter', logo: '/img/OpenRouter.png', color: 'cyan' }
{ key: 'grok', name: 'Grok', logo: '/img/Grok.png', color: 'yellow' }
].map(provider => (
<button
key={provider.key}
type="button"
onClick={() => {
// Get saved models for this provider, or use defaults
const providerKey = provider.key as ProviderKey;
const savedModels = providerModels[providerKey] || getDefaultModels(providerKey);
const updatedSettings = {
...ragSettings,
LLM_PROVIDER: provider.key
LLM_PROVIDER: providerKey,
MODEL_CHOICE: savedModels.chatModel,
EMBEDDING_MODEL: savedModels.embeddingModel
};
// Set models to provider-appropriate defaults when switching providers
// This ensures both LLM and embedding models switch when provider changes
const getDefaultChatModel = (provider: string): string => {
switch (provider) {
case 'openai': return 'gpt-4o-mini';
case 'anthropic': return 'claude-3-5-sonnet-20241022';
case 'google': return 'gemini-1.5-flash';
case 'grok': return 'grok-2-latest';
case 'ollama': return '';
case 'openrouter': return 'anthropic/claude-3.5-sonnet';
default: return 'gpt-4o-mini';
}
};
const getDefaultEmbeddingModel = (provider: string): string => {
switch (provider) {
case 'openai': return 'text-embedding-3-small';
case 'google': return 'text-embedding-004';
case 'ollama': return '';
case 'openrouter': return 'text-embedding-3-small';
case 'anthropic':
case 'grok':
default: return 'text-embedding-3-small';
}
};
updatedSettings.MODEL_CHOICE = getDefaultChatModel(provider.key);
updatedSettings.EMBEDDING_MODEL = getDefaultEmbeddingModel(provider.key);
setRagSettings(updatedSettings);
}}
className={`
relative p-3 rounded-lg border-2 transition-all duration-200 text-center
${ragSettings.LLM_PROVIDER === provider.key
? `border-${provider.color}-500 bg-${provider.color}-500/10 shadow-[0_0_15px_rgba(34,197,94,0.3)]`
? `${colorStyles[provider.key as ProviderKey]} shadow-[0_0_15px_rgba(34,197,94,0.3)]`
: 'border-gray-300 dark:border-gray-600 hover:border-gray-400 dark:hover:border-gray-500'
}
hover:scale-105 active:scale-95
@ -813,8 +887,8 @@ export const RAGSettings = ({
: ''
}`}
/>
<div className={`text-sm font-medium text-gray-700 dark:text-gray-300 ${
provider.key === 'openrouter' ? 'text-center' : ''
<div className={`font-medium text-gray-700 dark:text-gray-300 text-center ${
provider.key === 'openrouter' ? 'text-xs' : 'text-sm'
}`}>
{provider.name}
</div>
@ -842,13 +916,6 @@ export const RAGSettings = ({
);
}
})()}
{(provider.key === 'anthropic' || provider.key === 'grok' || provider.key === 'openrouter') && (
<div className="absolute inset-0 bg-black/20 rounded-lg flex items-center justify-center">
<div className="bg-yellow-500/80 text-black text-xs font-bold px-2 py-1 rounded transform -rotate-12">
Coming Soon
</div>
</div>
)}
</button>
))}
</div>
@ -1223,19 +1290,9 @@ export const RAGSettings = ({
</div>
)}
{ragSettings.LLM_PROVIDER === 'anthropic' && (
<div className="p-4 bg-orange-50 dark:bg-orange-900/20 border border-orange-200 dark:border-orange-800 rounded-lg mb-4">
<p className="text-sm text-orange-800 dark:text-orange-300">
Configure your Anthropic API key in the credentials section to use Claude models.
</p>
</div>
)}
{ragSettings.LLM_PROVIDER === 'groq' && (
<div className="p-4 bg-yellow-50 dark:bg-yellow-900/20 border border-yellow-200 dark:border-yellow-800 rounded-lg mb-4">
<p className="text-sm text-yellow-800 dark:text-yellow-300">
Groq provides fast inference with Llama, Mixtral, and Gemma models.
</p>
{shouldShowProviderAlert && (
<div className={`p-4 border rounded-lg mb-4 ${providerAlertClassName}`}>
<p className="text-sm">{providerAlertMessage}</p>
</div>
)}
@ -1853,94 +1910,56 @@ export const RAGSettings = ({
function getDisplayedChatModel(ragSettings: any): string {
const provider = ragSettings.LLM_PROVIDER || 'openai';
const modelChoice = ragSettings.MODEL_CHOICE;
// Check if the stored model is appropriate for the current provider
const isModelAppropriate = (model: string, provider: string): boolean => {
if (!model) return false;
switch (provider) {
case 'openai':
return model.startsWith('gpt-') || model.startsWith('o1-') || model.includes('text-davinci') || model.includes('text-embedding');
case 'anthropic':
return model.startsWith('claude-');
case 'google':
return model.startsWith('gemini-') || model.startsWith('text-embedding-');
case 'grok':
return model.startsWith('grok-');
case 'ollama':
return !model.startsWith('gpt-') && !model.startsWith('claude-') && !model.startsWith('gemini-') && !model.startsWith('grok-');
case 'openrouter':
return model.includes('/') || model.startsWith('anthropic/') || model.startsWith('openai/');
default:
return false;
}
};
// Use stored model if it's appropriate for the provider, otherwise use default
const useStoredModel = modelChoice && isModelAppropriate(modelChoice, provider);
// Always prioritize user input to allow editing
if (modelChoice !== undefined && modelChoice !== null) {
return modelChoice;
}
// Only use defaults when there's no stored value
switch (provider) {
case 'openai':
return useStoredModel ? modelChoice : 'gpt-4o-mini';
return 'gpt-4o-mini';
case 'anthropic':
return useStoredModel ? modelChoice : 'claude-3-5-sonnet-20241022';
return 'claude-3-5-sonnet-20241022';
case 'google':
return useStoredModel ? modelChoice : 'gemini-1.5-flash';
return 'gemini-1.5-flash';
case 'grok':
return useStoredModel ? modelChoice : 'grok-2-latest';
return 'grok-3-mini';
case 'ollama':
return useStoredModel ? modelChoice : '';
return '';
case 'openrouter':
return useStoredModel ? modelChoice : 'anthropic/claude-3.5-sonnet';
return 'anthropic/claude-3.5-sonnet';
default:
return useStoredModel ? modelChoice : 'gpt-4o-mini';
return 'gpt-4o-mini';
}
}
function getDisplayedEmbeddingModel(ragSettings: any): string {
const provider = ragSettings.LLM_PROVIDER || 'openai';
const embeddingModel = ragSettings.EMBEDDING_MODEL;
// Check if the stored embedding model is appropriate for the current provider
const isEmbeddingModelAppropriate = (model: string, provider: string): boolean => {
if (!model) return false;
switch (provider) {
case 'openai':
return model.startsWith('text-embedding-') || model.includes('ada-');
case 'anthropic':
return false; // Claude doesn't provide embedding models
case 'google':
return model.startsWith('text-embedding-') || model.startsWith('textembedding-') || model.includes('embedding');
case 'grok':
return false; // Grok doesn't provide embedding models
case 'ollama':
return !model.startsWith('text-embedding-') || model.includes('embed') || model.includes('arctic');
case 'openrouter':
return model.startsWith('text-embedding-') || model.includes('/');
default:
return false;
}
};
// Use stored model if it's appropriate for the provider, otherwise use default
const useStoredModel = embeddingModel && isEmbeddingModelAppropriate(embeddingModel, provider);
// Always prioritize user input to allow editing
if (embeddingModel !== undefined && embeddingModel !== null && embeddingModel !== '') {
return embeddingModel;
}
// Provide appropriate defaults based on LLM provider
switch (provider) {
case 'openai':
return useStoredModel ? embeddingModel : 'text-embedding-3-small';
case 'anthropic':
return 'Not available - Claude does not provide embedding models';
return 'text-embedding-3-small';
case 'google':
return useStoredModel ? embeddingModel : 'text-embedding-004';
case 'grok':
return 'Not available - Grok does not provide embedding models';
return 'text-embedding-004';
case 'ollama':
return useStoredModel ? embeddingModel : '';
return '';
case 'openrouter':
return useStoredModel ? embeddingModel : 'text-embedding-3-small';
return 'text-embedding-3-small'; // Default to OpenAI embedding for OpenRouter
case 'anthropic':
return 'text-embedding-3-small'; // Use OpenAI embeddings with Claude
case 'grok':
return 'text-embedding-3-small'; // Use OpenAI embeddings with Grok
default:
return useStoredModel ? embeddingModel : 'text-embedding-3-small';
return 'text-embedding-3-small';
}
}
@ -2035,4 +2054,4 @@ const CustomCheckbox = ({
</div>
</div>
);
};
};

View File

@ -14,6 +14,7 @@ from .internal_api import router as internal_router
from .knowledge_api import router as knowledge_router
from .mcp_api import router as mcp_router
from .projects_api import router as projects_router
from .providers_api import router as providers_router
from .settings_api import router as settings_router
__all__ = [
@ -23,4 +24,5 @@ __all__ = [
"projects_router",
"agent_chat_router",
"internal_router",
"providers_router",
]

View File

@ -18,6 +18,8 @@ from urllib.parse import urlparse
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from pydantic import BaseModel
# Basic validation - simplified inline version
# Import unified logging
from ..config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info
from ..services.crawler_manager import get_crawler
@ -62,26 +64,59 @@ async def _validate_provider_api_key(provider: str = None) -> None:
logger.info("🔑 Starting API key validation...")
try:
# Basic provider validation
if not provider:
provider = "openai"
else:
# Simple provider validation
allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"}
if provider not in allowed_providers:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid provider name",
"message": f"Provider '{provider}' not supported",
"error_type": "validation_error"
}
)
logger.info(f"🔑 Testing {provider.title()} API key with minimal embedding request...")
# Test API key with minimal embedding request - this will fail if key is invalid
from ..services.embeddings.embedding_service import create_embedding
test_result = await create_embedding(text="test")
if not test_result:
logger.error(f"{provider.title()} API key validation failed - no embedding returned")
raise HTTPException(
status_code=401,
detail={
"error": f"Invalid {provider.title()} API key",
"message": f"Please verify your {provider.title()} API key in Settings.",
"error_type": "authentication_failed",
"provider": provider
}
)
# Basic sanitization for logging
safe_provider = provider[:20] # Limit length
logger.info(f"🔑 Testing {safe_provider.title()} API key with minimal embedding request...")
try:
# Test API key with minimal embedding request using provider-scoped configuration
from ..services.embeddings.embedding_service import create_embedding
test_result = await create_embedding(text="test", provider=provider)
if not test_result:
logger.error(
f"{provider.title()} API key validation failed - no embedding returned"
)
raise HTTPException(
status_code=401,
detail={
"error": f"Invalid {provider.title()} API key",
"message": f"Please verify your {provider.title()} API key in Settings.",
"error_type": "authentication_failed",
"provider": provider,
},
)
except Exception as e:
logger.error(
f"{provider.title()} API key validation failed: {e}",
exc_info=True,
)
raise HTTPException(
status_code=401,
detail={
"error": f"Invalid {provider.title()} API key",
"message": f"Please verify your {provider.title()} API key in Settings. Error: {str(e)[:100]}",
"error_type": "authentication_failed",
"provider": provider,
},
)
logger.info(f"{provider.title()} API key validation successful")

View File

@ -0,0 +1,154 @@
"""
Provider status API endpoints for testing connectivity
Handles server-side provider connectivity testing without exposing API keys to frontend.
"""
import httpx
from fastapi import APIRouter, HTTPException, Path
from ..config.logfire_config import logfire
from ..services.credential_service import credential_service
# Provider validation - simplified inline version
router = APIRouter(prefix="/api/providers", tags=["providers"])
async def test_openai_connection(api_key: str) -> bool:
"""Test OpenAI API connectivity"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
"https://api.openai.com/v1/models",
headers={"Authorization": f"Bearer {api_key}"}
)
return response.status_code == 200
except Exception as e:
logfire.warning(f"OpenAI connectivity test failed: {e}")
return False
async def test_google_connection(api_key: str) -> bool:
"""Test Google AI API connectivity"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
"https://generativelanguage.googleapis.com/v1/models",
headers={"x-goog-api-key": api_key}
)
return response.status_code == 200
except Exception:
logfire.warning("Google AI connectivity test failed")
return False
async def test_anthropic_connection(api_key: str) -> bool:
"""Test Anthropic API connectivity"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
"https://api.anthropic.com/v1/models",
headers={
"x-api-key": api_key,
"anthropic-version": "2023-06-01"
}
)
return response.status_code == 200
except Exception as e:
logfire.warning(f"Anthropic connectivity test failed: {e}")
return False
async def test_openrouter_connection(api_key: str) -> bool:
"""Test OpenRouter API connectivity"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
"https://openrouter.ai/api/v1/models",
headers={"Authorization": f"Bearer {api_key}"}
)
return response.status_code == 200
except Exception as e:
logfire.warning(f"OpenRouter connectivity test failed: {e}")
return False
async def test_grok_connection(api_key: str) -> bool:
"""Test Grok API connectivity"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
"https://api.x.ai/v1/models",
headers={"Authorization": f"Bearer {api_key}"}
)
return response.status_code == 200
except Exception as e:
logfire.warning(f"Grok connectivity test failed: {e}")
return False
PROVIDER_TESTERS = {
"openai": test_openai_connection,
"google": test_google_connection,
"anthropic": test_anthropic_connection,
"openrouter": test_openrouter_connection,
"grok": test_grok_connection,
}
@router.get("/{provider}/status")
async def get_provider_status(
provider: str = Path(
...,
description="Provider name to test connectivity for",
regex="^[a-z0-9_]+$",
max_length=20
)
):
"""Test provider connectivity using server-side API key (secure)"""
try:
# Basic provider validation
allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"}
if provider not in allowed_providers:
raise HTTPException(
status_code=400,
detail=f"Invalid provider '{provider}'. Allowed providers: {sorted(allowed_providers)}"
)
# Basic sanitization for logging
safe_provider = provider[:20] # Limit length
logfire.info(f"Testing {safe_provider} connectivity server-side")
if provider not in PROVIDER_TESTERS:
raise HTTPException(
status_code=400,
detail=f"Provider '{provider}' not supported for connectivity testing"
)
# Get API key server-side (never expose to client)
key_name = f"{provider.upper()}_API_KEY"
api_key = await credential_service.get_credential(key_name, decrypt=True)
if not api_key or not isinstance(api_key, str) or not api_key.strip():
logfire.info(f"No API key configured for {safe_provider}")
return {"ok": False, "reason": "no_key"}
# Test connectivity using server-side key
tester = PROVIDER_TESTERS[provider]
is_connected = await tester(api_key)
logfire.info(f"{safe_provider} connectivity test result: {is_connected}")
return {
"ok": is_connected,
"reason": "connected" if is_connected else "connection_failed",
"provider": provider # Echo back validated provider name
}
except HTTPException:
# Re-raise HTTP exceptions (they're already properly formatted)
raise
except Exception as e:
# Basic error sanitization for logging
safe_error = str(e)[:100] # Limit length
logfire.error(f"Error testing {provider[:20]} connectivity: {safe_error}")
raise HTTPException(status_code=500, detail={"error": "Internal server error during connectivity test"})

View File

@ -26,6 +26,7 @@ from .api_routes.mcp_api import router as mcp_router
from .api_routes.ollama_api import router as ollama_router
from .api_routes.progress_api import router as progress_router
from .api_routes.projects_api import router as projects_router
from .api_routes.providers_api import router as providers_router
# Import modular API routers
from .api_routes.settings_api import router as settings_router
@ -186,6 +187,7 @@ app.include_router(progress_router)
app.include_router(agent_chat_router)
app.include_router(internal_router)
app.include_router(bug_report_router)
app.include_router(providers_router)
# Root endpoint

View File

@ -139,6 +139,7 @@ class CodeExtractionService:
source_id: str,
progress_callback: Callable | None = None,
cancellation_check: Callable[[], None] | None = None,
provider: str | None = None,
) -> int:
"""
Extract code examples from crawled documents and store them.
@ -204,7 +205,7 @@ class CodeExtractionService:
# Generate summaries for code blocks
summary_results = await self._generate_code_summaries(
all_code_blocks, summary_callback, cancellation_check
all_code_blocks, summary_callback, cancellation_check, provider
)
# Prepare code examples for storage
@ -223,7 +224,7 @@ class CodeExtractionService:
# Store code examples in database
return await self._store_code_examples(
storage_data, url_to_full_document, storage_callback
storage_data, url_to_full_document, storage_callback, provider
)
async def _extract_code_blocks_from_documents(
@ -1523,6 +1524,7 @@ class CodeExtractionService:
all_code_blocks: list[dict[str, Any]],
progress_callback: Callable | None = None,
cancellation_check: Callable[[], None] | None = None,
provider: str | None = None,
) -> list[dict[str, str]]:
"""
Generate summaries for all code blocks.
@ -1587,7 +1589,7 @@ class CodeExtractionService:
try:
results = await generate_code_summaries_batch(
code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback
code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback, provider=provider
)
# Ensure all results are valid dicts
@ -1667,6 +1669,7 @@ class CodeExtractionService:
storage_data: dict[str, list[Any]],
url_to_full_document: dict[str, str],
progress_callback: Callable | None = None,
provider: str | None = None,
) -> int:
"""
Store code examples in the database.
@ -1709,7 +1712,7 @@ class CodeExtractionService:
batch_size=20,
url_to_full_document=url_to_full_document,
progress_callback=storage_progress_callback,
provider=None, # Use configured provider
provider=provider,
)
# Report completion of code extraction/storage phase

View File

@ -475,12 +475,24 @@ class CrawlingService:
)
try:
# Extract provider from request or use credential service default
provider = request.get("provider")
if not provider:
try:
from ..credential_service import credential_service
provider_config = await credential_service.get_active_provider("llm")
provider = provider_config.get("provider", "openai")
except Exception as e:
logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai")
provider = "openai"
code_examples_count = await self.doc_storage_ops.extract_and_store_code_examples(
crawl_results,
storage_results["url_to_full_document"],
storage_results["source_id"],
code_progress_callback,
self._check_cancellation,
provider,
)
except RuntimeError as e:
# Code extraction failed, continue crawl with warning

View File

@ -351,6 +351,7 @@ class DocumentStorageOperations:
source_id: str,
progress_callback: Callable | None = None,
cancellation_check: Callable[[], None] | None = None,
provider: str | None = None,
) -> int:
"""
Extract code examples from crawled documents and store them.
@ -361,12 +362,13 @@ class DocumentStorageOperations:
source_id: The unique source_id for all documents
progress_callback: Optional callback for progress updates
cancellation_check: Optional function to check for cancellation
provider: Optional LLM provider to use for code summaries
Returns:
Number of code examples stored
"""
result = await self.code_extraction_service.extract_and_store_code_examples(
crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check
crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check, provider
)
return result

View File

@ -36,6 +36,44 @@ class CredentialItem:
description: str | None = None
def _detect_embedding_provider_from_model(embedding_model: str) -> str:
"""
Detect the appropriate embedding provider based on model name.
Args:
embedding_model: The embedding model name
Returns:
Provider name: 'google', 'openai', or 'openai' (default)
"""
if not embedding_model:
return "openai" # Default
model_lower = embedding_model.lower()
# Google embedding models
google_patterns = [
"text-embedding-004",
"text-embedding-005",
"text-multilingual-embedding",
"gemini-embedding",
"multimodalembedding"
]
if any(pattern in model_lower for pattern in google_patterns):
return "google"
# OpenAI embedding models (and default for unknown)
openai_patterns = [
"text-embedding-ada-002",
"text-embedding-3-small",
"text-embedding-3-large"
]
# Default to OpenAI for OpenAI models or unknown models
return "openai"
class CredentialService:
"""Service for managing application credentials and configuration."""
@ -239,6 +277,14 @@ class CredentialService:
self._rag_cache_timestamp = None
logger.debug(f"Invalidated RAG settings cache due to update of {key}")
# Also invalidate provider service cache to ensure immediate effect
try:
from .llm_provider_service import clear_provider_cache
clear_provider_cache()
logger.debug("Also cleared LLM provider service cache")
except Exception as e:
logger.warning(f"Failed to clear provider service cache: {e}")
# Also invalidate LLM provider service cache for provider config
try:
from . import llm_provider_service
@ -281,6 +327,14 @@ class CredentialService:
self._rag_cache_timestamp = None
logger.debug(f"Invalidated RAG settings cache due to deletion of {key}")
# Also invalidate provider service cache to ensure immediate effect
try:
from .llm_provider_service import clear_provider_cache
clear_provider_cache()
logger.debug("Also cleared LLM provider service cache")
except Exception as e:
logger.warning(f"Failed to clear provider service cache: {e}")
# Also invalidate LLM provider service cache for provider config
try:
from . import llm_provider_service
@ -419,8 +473,33 @@ class CredentialService:
# Get RAG strategy settings (where UI saves provider selection)
rag_settings = await self.get_credentials_by_category("rag_strategy")
# Get the selected provider
provider = rag_settings.get("LLM_PROVIDER", "openai")
# Get the selected provider based on service type
if service_type == "embedding":
# Get the LLM provider setting to determine embedding provider
llm_provider = rag_settings.get("LLM_PROVIDER", "openai")
embedding_model = rag_settings.get("EMBEDDING_MODEL", "text-embedding-3-small")
# Determine embedding provider based on LLM provider
if llm_provider == "google":
provider = "google"
elif llm_provider == "ollama":
provider = "ollama"
elif llm_provider == "openrouter":
# OpenRouter supports both OpenAI and Google embedding models
provider = _detect_embedding_provider_from_model(embedding_model)
elif llm_provider in ["anthropic", "grok"]:
# Anthropic and Grok support both OpenAI and Google embedding models
provider = _detect_embedding_provider_from_model(embedding_model)
else:
# Default case (openai, or unknown providers)
provider = "openai"
logger.debug(f"Determined embedding provider '{provider}' from LLM provider '{llm_provider}' and embedding model '{embedding_model}'")
else:
provider = rag_settings.get("LLM_PROVIDER", "openai")
# Ensure provider is a valid string, not a boolean or other type
if not isinstance(provider, str) or provider.lower() in ("true", "false", "none", "null"):
provider = "openai"
# Get API key for this provider
api_key = await self._get_provider_api_key(provider)
@ -464,6 +543,9 @@ class CredentialService:
key_mapping = {
"openai": "OPENAI_API_KEY",
"google": "GOOGLE_API_KEY",
"openrouter": "OPENROUTER_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"grok": "GROK_API_KEY",
"ollama": None, # No API key needed
}
@ -478,6 +560,12 @@ class CredentialService:
return rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434/v1")
elif provider == "google":
return "https://generativelanguage.googleapis.com/v1beta/openai/"
elif provider == "openrouter":
return "https://openrouter.ai/api/v1"
elif provider == "anthropic":
return "https://api.anthropic.com/v1"
elif provider == "grok":
return "https://api.x.ai/v1"
return None # Use default for OpenAI
async def set_active_provider(self, provider: str, service_type: str = "llm") -> bool:
@ -485,7 +573,7 @@ class CredentialService:
try:
# For now, we'll update the RAG strategy settings
return await self.set_credential(
"llm_provider",
"LLM_PROVIDER",
provider,
category="rag_strategy",
description=f"Active {service_type} provider",

View File

@ -10,7 +10,13 @@ import os
import openai
from ...config.logfire_config import search_logger
from ..llm_provider_service import get_llm_client
from ..credential_service import credential_service
from ..llm_provider_service import (
extract_message_text,
get_llm_client,
prepare_chat_completion_params,
requires_max_completion_tokens,
)
from ..threading_service import get_threading_service
@ -32,8 +38,6 @@ async def generate_contextual_embedding(
"""
# Model choice is a RAG setting, get from credential service
try:
from ...services.credential_service import credential_service
model_choice = await credential_service.get_credential("MODEL_CHOICE", "gpt-4.1-nano")
except Exception as e:
# Fallback to environment variable or default
@ -65,20 +69,25 @@ Please give a short succinct context to situate this chunk within the overall do
# Get model from provider configuration
model = await _get_model_choice(provider)
response = await client.chat.completions.create(
model=model,
messages=[
# Prepare parameters and convert max_tokens for GPT-5/reasoning models
params = {
"model": model,
"messages": [
{
"role": "system",
"content": "You are a helpful assistant that provides concise contextual information.",
},
{"role": "user", "content": prompt},
],
temperature=0.3,
max_tokens=200,
)
"temperature": 0.3,
"max_tokens": 1200 if requires_max_completion_tokens(model) else 200, # Much more tokens for reasoning models (GPT-5 needs extra for reasoning process)
}
final_params = prepare_chat_completion_params(model, params)
response = await client.chat.completions.create(**final_params)
context = response.choices[0].message.content.strip()
choice = response.choices[0] if response.choices else None
context, _, _ = extract_message_text(choice)
context = context.strip()
contextual_text = f"{context}\n---\n{chunk}"
return contextual_text, True
@ -111,7 +120,7 @@ async def process_chunk_with_context(
async def _get_model_choice(provider: str | None = None) -> str:
"""Get model choice from credential service."""
"""Get model choice from credential service with centralized defaults."""
from ..credential_service import credential_service
# Get the active provider configuration
@ -119,31 +128,36 @@ async def _get_model_choice(provider: str | None = None) -> str:
model = provider_config.get("chat_model", "").strip() # Strip whitespace
provider_name = provider_config.get("provider", "openai")
# Handle empty model case - fallback to provider-specific defaults or explicit config
# Handle empty model case - use centralized defaults
if not model:
search_logger.warning(f"chat_model is empty for provider {provider_name}, using fallback logic")
search_logger.warning(f"chat_model is empty for provider {provider_name}, using centralized defaults")
# Special handling for Ollama to check specific credential
if provider_name == "ollama":
# Try to get OLLAMA_CHAT_MODEL specifically
try:
ollama_model = await credential_service.get_credential("OLLAMA_CHAT_MODEL")
if ollama_model and ollama_model.strip():
model = ollama_model.strip()
search_logger.info(f"Using OLLAMA_CHAT_MODEL fallback: {model}")
else:
# Use a sensible Ollama default
# Use default for Ollama
model = "llama3.2:latest"
search_logger.info(f"Using Ollama default model: {model}")
search_logger.info(f"Using Ollama default: {model}")
except Exception as e:
search_logger.error(f"Error getting OLLAMA_CHAT_MODEL: {e}")
model = "llama3.2:latest"
search_logger.info(f"Using Ollama fallback model: {model}")
elif provider_name == "google":
model = "gemini-1.5-flash"
search_logger.info(f"Using Ollama fallback: {model}")
else:
# OpenAI or other providers
model = "gpt-4o-mini"
# Use provider-specific defaults
provider_defaults = {
"openai": "gpt-4o-mini",
"openrouter": "anthropic/claude-3.5-sonnet",
"google": "gemini-1.5-flash",
"anthropic": "claude-3-5-haiku-20241022",
"grok": "grok-3-mini"
}
model = provider_defaults.get(provider_name, "gpt-4o-mini")
search_logger.debug(f"Using default model for provider {provider_name}: {model}")
search_logger.debug(f"Using model from credential service: {model}")
return model
@ -174,38 +188,48 @@ async def generate_contextual_embeddings_batch(
model_choice = await _get_model_choice(provider)
# Build batch prompt for ALL chunks at once
batch_prompt = (
"Process the following chunks and provide contextual information for each:\\n\\n"
)
batch_prompt = "Process the following chunks and provide contextual information for each:\n\n"
for i, (doc, chunk) in enumerate(zip(full_documents, chunks, strict=False)):
# Use only 2000 chars of document context to save tokens
doc_preview = doc[:2000] if len(doc) > 2000 else doc
batch_prompt += f"CHUNK {i + 1}:\\n"
batch_prompt += f"<document_preview>\\n{doc_preview}\\n</document_preview>\\n"
batch_prompt += f"<chunk>\\n{chunk[:500]}\\n</chunk>\\n\\n" # Limit chunk preview
batch_prompt += f"CHUNK {i + 1}:\n"
batch_prompt += f"<document_preview>\n{doc_preview}\n</document_preview>\n"
batch_prompt += f"<chunk>\n{chunk[:500]}\n</chunk>\n\n" # Limit chunk preview
batch_prompt += "For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. Format your response as:\\nCHUNK 1: [context]\\nCHUNK 2: [context]\\netc."
batch_prompt += (
"For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. "
"Format your response as:\nCHUNK 1: [context]\nCHUNK 2: [context]\netc."
)
# Make single API call for ALL chunks
response = await client.chat.completions.create(
model=model_choice,
messages=[
# Prepare parameters and convert max_tokens for GPT-5/reasoning models
batch_params = {
"model": model_choice,
"messages": [
{
"role": "system",
"content": "You are a helpful assistant that generates contextual information for document chunks.",
},
{"role": "user", "content": batch_prompt},
],
temperature=0,
max_tokens=100 * len(chunks), # Limit response size
)
"temperature": 0,
"max_tokens": (600 if requires_max_completion_tokens(model_choice) else 100) * len(chunks), # Much more tokens for reasoning models (GPT-5 needs extra reasoning space)
}
final_batch_params = prepare_chat_completion_params(model_choice, batch_params)
response = await client.chat.completions.create(**final_batch_params)
# Parse response
response_text = response.choices[0].message.content
choice = response.choices[0] if response.choices else None
response_text, _, _ = extract_message_text(choice)
if not response_text:
search_logger.error(
"Empty response from LLM when generating contextual embeddings batch"
)
return [(chunk, False) for chunk in chunks]
# Extract contexts from response
lines = response_text.strip().split("\\n")
lines = response_text.strip().split("\n")
chunk_contexts = {}
for line in lines:
@ -245,4 +269,4 @@ async def generate_contextual_embeddings_batch(
except Exception as e:
search_logger.error(f"Error in contextual embedding batch: {e}")
# Return non-contextual for all chunks
return [(chunk, False) for chunk in chunks]
return [(chunk, False) for chunk in chunks]

View File

@ -13,7 +13,7 @@ import openai
from ...config.logfire_config import safe_span, search_logger
from ..credential_service import credential_service
from ..llm_provider_service import get_embedding_model, get_llm_client
from ..llm_provider_service import get_embedding_model, get_llm_client, is_google_embedding_model, is_openai_embedding_model
from ..threading_service import get_threading_service
from .embedding_exceptions import (
EmbeddingAPIError,
@ -152,34 +152,56 @@ async def create_embeddings_batch(
if not texts:
return EmbeddingBatchResult()
result = EmbeddingBatchResult()
# Validate that all items in texts are strings
validated_texts = []
for i, text in enumerate(texts):
if not isinstance(text, str):
search_logger.error(
f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True
)
# Try to convert to string
try:
validated_texts.append(str(text))
except Exception as e:
search_logger.error(
f"Failed to convert text at index {i} to string: {e}", exc_info=True
)
validated_texts.append("") # Use empty string as fallback
else:
if isinstance(text, str):
validated_texts.append(text)
continue
search_logger.error(
f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True
)
try:
converted = str(text)
validated_texts.append(converted)
except Exception as conversion_error:
search_logger.error(
f"Failed to convert text at index {i} to string: {conversion_error}",
exc_info=True,
)
result.add_failure(
repr(text),
EmbeddingAPIError("Invalid text type", original_error=conversion_error),
batch_index=None,
)
texts = validated_texts
result = EmbeddingBatchResult()
threading_service = get_threading_service()
with safe_span(
"create_embeddings_batch", text_count=len(texts), total_chars=sum(len(t) for t in texts)
) as span:
try:
async with get_llm_client(provider=provider, use_embedding_provider=True) as client:
# Intelligent embedding provider routing based on model type
# Get the embedding model first to determine the correct provider
embedding_model = await get_embedding_model(provider=provider)
# Route to correct provider based on model type
if is_google_embedding_model(embedding_model):
embedding_provider = "google"
search_logger.info(f"Routing to Google for embedding model: {embedding_model}")
elif is_openai_embedding_model(embedding_model) or "openai/" in embedding_model.lower():
embedding_provider = "openai"
search_logger.info(f"Routing to OpenAI for embedding model: {embedding_model}")
else:
# Keep original provider for ollama and other providers
embedding_provider = provider
search_logger.info(f"Using original provider '{provider}' for embedding model: {embedding_model}")
async with get_llm_client(provider=embedding_provider, use_embedding_provider=True) as client:
# Load batch size and dimensions from settings
try:
rag_settings = await credential_service.get_credentials_by_category(
@ -220,7 +242,8 @@ async def create_embeddings_batch(
while retry_count < max_retries:
try:
# Create embeddings for this batch
embedding_model = await get_embedding_model(provider=provider)
embedding_model = await get_embedding_model(provider=embedding_provider)
response = await client.embeddings.create(
model=embedding_model,
input=batch,

File diff suppressed because it is too large

View File

@ -2,7 +2,7 @@
Provider Discovery Service
Discovers available models, checks provider health, and provides model specifications
for OpenAI, Google Gemini, Ollama, and Anthropic providers.
for OpenAI, Google Gemini, Ollama, Anthropic, and Grok providers.
"""
import time
@ -359,6 +359,36 @@ class ProviderDiscoveryService:
return models
async def discover_grok_models(self, api_key: str) -> list[ModelSpec]:
"""Discover available Grok models."""
cache_key = f"grok_models_{hash(api_key)}"
cached = self._get_cached_result(cache_key)
if cached:
return cached
models = []
try:
# Grok model specifications
model_specs = [
ModelSpec("grok-3-mini", "grok", 32768, True, True, False, None, 0.15, 0.60, "Fast and efficient Grok model"),
ModelSpec("grok-3", "grok", 32768, True, True, False, None, 2.00, 10.00, "Standard Grok model"),
ModelSpec("grok-4", "grok", 32768, True, True, False, None, 5.00, 25.00, "Advanced Grok model"),
ModelSpec("grok-2-vision", "grok", 8192, True, True, True, None, 3.00, 15.00, "Grok model with vision capabilities"),
ModelSpec("grok-2-latest", "grok", 8192, True, True, False, None, 2.00, 10.00, "Latest Grok 2 model"),
]
# Test connectivity - Grok doesn't have a models list endpoint,
# so we'll just return the known models if API key is provided
if api_key:
models = model_specs
self._cache_result(cache_key, models)
logger.info(f"Discovered {len(models)} Grok models")
except Exception as e:
logger.error(f"Error discovering Grok models: {e}")
return models
async def check_provider_health(self, provider: str, config: dict[str, Any]) -> ProviderStatus:
"""Check health and connectivity status of a provider."""
start_time = time.time()
@ -456,6 +486,23 @@ class ProviderDiscoveryService:
last_checked=time.time()
)
elif provider == "grok":
api_key = config.get("api_key")
if not api_key:
return ProviderStatus(provider, False, None, "API key not configured")
# Grok doesn't have a health check endpoint, so we'll assume it's available
# if API key is provided. In a real implementation, you might want to make a
# small test request to verify the key is valid.
response_time = (time.time() - start_time) * 1000
return ProviderStatus(
provider="grok",
is_available=True,
response_time_ms=response_time,
models_available=5, # Known model count
last_checked=time.time()
)
else:
return ProviderStatus(provider, False, None, f"Unknown provider: {provider}")
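
The comment above leaves the "small test request" as future work. A hedged sketch of what it could look like, using the OpenAI-compatible client this PR already routes Grok through; the base URL is an assumption, while grok-3-mini comes from the model specs above.

import openai

async def verify_grok_key(api_key: str) -> bool:
    """Best-effort key check: one minimal chat completion against the assumed xAI endpoint."""
    client = openai.AsyncOpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
    try:
        await client.chat.completions.create(
            model="grok-3-mini",
            messages=[{"role": "user", "content": "ping"}],
            max_tokens=1,
        )
        return True
    except Exception:
        return False
    finally:
        await client.close()
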
@ -496,6 +543,11 @@ class ProviderDiscoveryService:
if anthropic_key:
providers["anthropic"] = await self.discover_anthropic_models(anthropic_key)
# Grok
grok_key = await credential_service.get_credential("GROK_API_KEY")
if grok_key:
providers["grok"] = await self.discover_grok_models(grok_key)
except Exception as e:
logger.error(f"Error getting all available models: {e}")


@ -11,7 +11,7 @@ from supabase import Client
from ..config.logfire_config import get_logger, search_logger
from .client_manager import get_supabase_client
from .llm_provider_service import get_llm_client
from .llm_provider_service import extract_message_text, get_llm_client
logger = get_logger(__name__)
@ -72,20 +72,21 @@ The above content is from the documentation for '{source_id}'. Please provide a
)
# Extract the generated summary with proper error handling
if not response or not response.choices or len(response.choices) == 0:
search_logger.error(f"Empty or invalid response from LLM for {source_id}")
return default_summary
message_content = response.choices[0].message.content
if message_content is None:
search_logger.error(f"LLM returned None content for {source_id}")
return default_summary
summary = message_content.strip()
# Ensure the summary is not too long
if len(summary) > max_length:
summary = summary[:max_length] + "..."
if not response or not response.choices or len(response.choices) == 0:
search_logger.error(f"Empty or invalid response from LLM for {source_id}")
return default_summary
choice = response.choices[0]
summary_text, _, _ = extract_message_text(choice)
if not summary_text:
search_logger.error(f"LLM returned None content for {source_id}")
return default_summary
summary = summary_text.strip()
# Ensure the summary is not too long
if len(summary) > max_length:
summary = summary[:max_length] + "..."
return summary
@ -187,7 +188,9 @@ Generate only the title, nothing else."""
],
)
generated_title = response.choices[0].message.content.strip()
choice = response.choices[0]
generated_title, _, _ = extract_message_text(choice)
generated_title = generated_title.strip()
# Clean up the title
generated_title = generated_title.strip("\"'")
if len(generated_title) < 50: # Sanity check


@ -8,6 +8,7 @@ import asyncio
import json
import os
import re
import time
from collections import defaultdict, deque
from collections.abc import Callable
from difflib import SequenceMatcher
@ -17,25 +18,104 @@ from urllib.parse import urlparse
from supabase import Client
from ...config.logfire_config import search_logger
from ..credential_service import credential_service
from ..embeddings.contextual_embedding_service import generate_contextual_embeddings_batch
from ..embeddings.embedding_service import create_embeddings_batch
from ..llm_provider_service import (
extract_json_from_reasoning,
extract_message_text,
get_llm_client,
prepare_chat_completion_params,
synthesize_json_from_reasoning,
)
def _get_model_choice() -> str:
"""Get MODEL_CHOICE with direct fallback."""
def _extract_json_payload(raw_response: str, context_code: str = "", language: str = "") -> str:
"""Return the best-effort JSON object from an LLM response."""
if not raw_response:
return raw_response
cleaned = raw_response.strip()
# Check if this looks like reasoning text first
if _is_reasoning_text_response(cleaned):
# Try intelligent extraction from reasoning text with context
extracted = extract_json_from_reasoning(cleaned, context_code, language)
if extracted:
return extracted
# extract_json_from_reasoning may return nothing; synthesize a fallback JSON if so
fallback_json = synthesize_json_from_reasoning("", context_code, language)
if fallback_json:
return fallback_json
# If all else fails, return a minimal valid JSON object to avoid downstream errors
return '{"example_name": "Code Example", "summary": "Code example extracted from context."}'
if cleaned.startswith("```"):
lines = cleaned.splitlines()
# Drop opening fence
lines = lines[1:]
# Drop closing fence if present
if lines and lines[-1].strip().startswith("```"):
lines = lines[:-1]
cleaned = "\n".join(lines).strip()
# Trim any leading/trailing text outside the outermost JSON braces
start = cleaned.find("{")
end = cleaned.rfind("}")
if start != -1 and end != -1 and end >= start:
cleaned = cleaned[start : end + 1]
return cleaned.strip()
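
A quick illustration of the fence-stripping path (illustrative, not part of the patch):

raw = '```json\n{"example_name": "Demo", "summary": "Shows a minimal client setup."}\n```'
print(_extract_json_payload(raw))
# -> {"example_name": "Demo", "summary": "Shows a minimal client setup."}
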
REASONING_STARTERS = [
"okay, let's see", "okay, let me", "let me think", "first, i need to", "looking at this",
"i need to", "analyzing", "let me work through", "thinking about", "let me see"
]
def _is_reasoning_text_response(text: str) -> bool:
"""Detect if response is reasoning text rather than direct JSON."""
if not text or len(text) < 20:
return False
text_lower = text.lower().strip()
# Check if it's clearly not JSON (starts with reasoning text)
starts_with_reasoning = any(text_lower.startswith(starter) for starter in REASONING_STARTERS)
# Check if it lacks immediate JSON structure
lacks_immediate_json = not text_lower.lstrip().startswith('{')
return starts_with_reasoning or (lacks_immediate_json and any(pattern in text_lower for pattern in REASONING_STARTERS))
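
For example (illustrative, not part of the patch), the first call below is flagged as reasoning text while direct JSON is not:

_is_reasoning_text_response("Okay, let's see. This snippet configures the client and then...")  # True
_is_reasoning_text_response('{"example_name": "Demo", "summary": "Direct JSON answer."}')       # False
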
async def _get_model_choice() -> str:
"""Get MODEL_CHOICE with provider-aware defaults from centralized service."""
try:
# Direct cache/env fallback
from ..credential_service import credential_service
# Get the active provider configuration
provider_config = await credential_service.get_active_provider("llm")
active_provider = provider_config.get("provider", "openai")
model = provider_config.get("chat_model")
if credential_service._cache_initialized and "MODEL_CHOICE" in credential_service._cache:
model = credential_service._cache["MODEL_CHOICE"]
else:
model = os.getenv("MODEL_CHOICE", "gpt-4.1-nano")
search_logger.debug(f"Using model choice: {model}")
# If no custom model is set, use provider-specific defaults
if not model or model.strip() == "":
# Provider-specific defaults
provider_defaults = {
"openai": "gpt-4o-mini",
"openrouter": "anthropic/claude-3.5-sonnet",
"google": "gemini-1.5-flash",
"ollama": "llama3.2:latest",
"anthropic": "claude-3-5-haiku-20241022",
"grok": "grok-3-mini"
}
model = provider_defaults.get(active_provider, "gpt-4o-mini")
search_logger.debug(f"Using default model for provider {active_provider}: {model}")
search_logger.debug(f"Using model for provider {active_provider}: {model}")
return model
except Exception as e:
search_logger.warning(f"Error getting model choice: {e}, using default")
return "gpt-4.1-nano"
return "gpt-4o-mini"
def _get_max_workers() -> int:
@ -155,6 +235,7 @@ def _select_best_code_variant(similar_blocks: list[dict[str, Any]]) -> dict[str,
return best_block
def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[dict[str, Any]]:
"""
Extract code blocks from markdown content along with context.
@ -168,8 +249,6 @@ def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[d
"""
# Load all code extraction settings with direct fallback
try:
from ...services.credential_service import credential_service
def _get_setting_fallback(key: str, default: str) -> str:
if credential_service._cache_initialized and key in credential_service._cache:
return credential_service._cache[key]
@ -507,7 +586,7 @@ def generate_code_example_summary(
A dictionary with 'summary' and 'example_name'
"""
import asyncio
# Run the async version in the current thread
return asyncio.run(_generate_code_example_summary_async(code, context_before, context_after, language, provider))
@ -518,13 +597,22 @@ async def _generate_code_example_summary_async(
"""
Async version of generate_code_example_summary using unified LLM provider service.
"""
from ..llm_provider_service import get_llm_client
# Get model choice from credential service (RAG setting)
model_choice = _get_model_choice()
# Create the prompt
prompt = f"""<context_before>
# Get model choice from credential service (RAG setting)
model_choice = await _get_model_choice()
# If provider is not specified, get it from credential service
if provider is None:
try:
provider_config = await credential_service.get_active_provider("llm")
provider = provider_config.get("provider", "openai")
search_logger.debug(f"Auto-detected provider from credential service: {provider}")
except Exception as e:
search_logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai")
provider = "openai"
# Create the prompt variants: base prompt, guarded prompt (JSON reminder), and strict prompt for retries
base_prompt = f"""<context_before>
{context_before[-500:] if len(context_before) > 500 else context_before}
</context_before>
@ -548,6 +636,16 @@ Format your response as JSON:
"summary": "2-3 sentence description of what the code demonstrates"
}}
"""
guard_prompt = (
base_prompt
+ "\n\nImportant: Respond with a valid JSON object that exactly matches the keys "
'{"example_name": string, "summary": string}. Do not include commentary, '
"markdown fences, or reasoning notes."
)
strict_prompt = (
guard_prompt
+ "\n\nSecond attempt enforcement: Return JSON only with the exact schema. No additional text or reasoning content."
)
try:
# Use unified LLM provider service
@ -555,25 +653,261 @@ Format your response as JSON:
search_logger.info(
f"Generating summary for {hash(code) & 0xffffff:06x} using model: {model_choice}"
)
response = await client.chat.completions.create(
model=model_choice,
messages=[
{
"role": "system",
"content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.",
},
{"role": "user", "content": prompt},
],
response_format={"type": "json_object"},
max_tokens=500,
temperature=0.3,
provider_lower = provider.lower()
is_grok_model = (provider_lower == "grok") or ("grok" in model_choice.lower())
supports_response_format_base = (
provider_lower in {"openai", "google", "anthropic"}
or (provider_lower == "openrouter" and model_choice.startswith("openai/"))
)
response_content = response.choices[0].message.content.strip()
last_response_obj = None
last_elapsed_time = None
last_response_content = ""
last_json_error: json.JSONDecodeError | None = None
for enforce_json, current_prompt in ((False, guard_prompt), (True, strict_prompt)):
request_params = {
"model": model_choice,
"messages": [
{
"role": "system",
"content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.",
},
{"role": "user", "content": current_prompt},
],
"max_tokens": 2000,
"temperature": 0.3,
}
should_use_response_format = False
if enforce_json:
if not is_grok_model and (supports_response_format_base or provider_lower == "openrouter"):
should_use_response_format = True
else:
if supports_response_format_base:
should_use_response_format = True
if should_use_response_format:
request_params["response_format"] = {"type": "json_object"}
if is_grok_model:
unsupported_params = ["presence_penalty", "frequency_penalty", "stop", "reasoning_effort"]
for param in unsupported_params:
if param in request_params:
removed_value = request_params.pop(param)
search_logger.warning(f"Removed unsupported Grok parameter '{param}': {removed_value}")
supported_params = ["model", "messages", "max_tokens", "temperature", "response_format", "stream", "tools", "tool_choice"]
for param in list(request_params.keys()):
if param not in supported_params:
search_logger.warning(f"Parameter '{param}' may not be supported by Grok reasoning models")
start_time = time.time()
max_retries = 3 if is_grok_model else 1
retry_delay = 1.0
response_content_local = ""
reasoning_text_local = ""
json_error_occurred = False
for attempt in range(max_retries):
try:
if is_grok_model and attempt > 0:
search_logger.info(f"Grok retry attempt {attempt + 1}/{max_retries} after {retry_delay:.1f}s delay")
await asyncio.sleep(retry_delay)
final_params = prepare_chat_completion_params(model_choice, request_params)
response = await client.chat.completions.create(**final_params)
last_response_obj = response
choice = response.choices[0] if response.choices else None
message = choice.message if choice and hasattr(choice, "message") else None
response_content_local = ""
reasoning_text_local = ""
if choice:
response_content_local, reasoning_text_local, _ = extract_message_text(choice)
# Enhanced logging for response analysis
if message and reasoning_text_local:
content_preview = response_content_local[:100] if response_content_local else "None"
reasoning_preview = reasoning_text_local[:100] if reasoning_text_local else "None"
search_logger.debug(
f"Response has reasoning content - content: '{content_preview}', reasoning: '{reasoning_preview}'"
)
if response_content_local:
last_response_content = response_content_local.strip()
# Pre-validate response before processing
if len(last_response_content) < 20 or (len(last_response_content) < 50 and not last_response_content.strip().startswith('{')):
# Very minimal response - likely "Okay\nOkay" type
search_logger.debug(f"Minimal response detected: {repr(last_response_content)}")
# Generate fallback directly from context
fallback_json = synthesize_json_from_reasoning("", code, language)
if fallback_json:
try:
result = json.loads(fallback_json)
final_result = {
"example_name": result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
"summary": result.get("summary", "Code example for demonstration purposes."),
}
search_logger.info(f"Generated fallback summary from context - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}")
return final_result
except json.JSONDecodeError:
pass # Continue to normal error handling
else:
# Even synthesis failed - provide hardcoded fallback for minimal responses
final_result = {
"example_name": f"Code Example{f' ({language})' if language else ''}",
"summary": "Code example extracted from development context.",
}
search_logger.info(f"Used hardcoded fallback for minimal response - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}")
return final_result
payload = _extract_json_payload(last_response_content, code, language)
if payload != last_response_content:
search_logger.debug(
f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..."
)
try:
result = json.loads(payload)
if not result.get("example_name") or not result.get("summary"):
search_logger.warning(f"Incomplete response from LLM: {result}")
final_result = {
"example_name": result.get(
"example_name", f"Code Example{f' ({language})' if language else ''}"
),
"summary": result.get("summary", "Code example for demonstration purposes."),
}
search_logger.info(
f"Generated code example summary - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}"
)
return final_result
except json.JSONDecodeError as json_error:
last_json_error = json_error
json_error_occurred = True
snippet = last_response_content[:200]
if not enforce_json:
# Check if this was reasoning text that couldn't be parsed
if _is_reasoning_text_response(last_response_content):
search_logger.debug(
f"Reasoning text detected but no JSON extracted. Response snippet: {repr(snippet)}"
)
else:
search_logger.warning(
f"Failed to parse JSON response from LLM (non-strict attempt). Error: {json_error}. Response snippet: {repr(snippet)}"
)
break
else:
search_logger.error(
f"Strict JSON enforcement still failed to produce valid JSON: {json_error}. Response snippet: {repr(snippet)}"
)
break
elif is_grok_model and attempt < max_retries - 1:
search_logger.warning(f"Grok empty response on attempt {attempt + 1}, retrying...")
retry_delay *= 2
continue
else:
break
except Exception as e:
if is_grok_model and attempt < max_retries - 1:
search_logger.error(f"Grok request failed on attempt {attempt + 1}: {e}, retrying...")
retry_delay *= 2
continue
else:
raise
if is_grok_model:
elapsed_time = time.time() - start_time
last_elapsed_time = elapsed_time
search_logger.debug(f"Grok total response time: {elapsed_time:.2f}s")
if json_error_occurred:
if not enforce_json:
continue
else:
break
if response_content_local:
# We would have returned already on success; if we reach here, parsing failed but we are not retrying
continue
response_content = last_response_content
response = last_response_obj
elapsed_time = last_elapsed_time if last_elapsed_time is not None else 0.0
if last_json_error is not None and response_content:
search_logger.error(
f"LLM response after strict enforcement was still not valid JSON: {last_json_error}. Clearing response to trigger error handling."
)
response_content = ""
if not response_content:
search_logger.error(f"Empty response from LLM for model: {model_choice} (provider: {provider})")
if is_grok_model:
search_logger.error("Grok empty response debugging:")
search_logger.error(f" - Request took: {elapsed_time:.2f}s")
search_logger.error(f" - Response status: {getattr(response, 'status_code', 'N/A')}")
search_logger.error(f" - Response headers: {getattr(response, 'headers', 'N/A')}")
search_logger.error(f" - Full response: {response}")
search_logger.error(f" - Response choices length: {len(response.choices) if response.choices else 0}")
if response.choices:
search_logger.error(f" - First choice: {response.choices[0]}")
search_logger.error(f" - Message content: '{response.choices[0].message.content}'")
search_logger.error(f" - Message role: {response.choices[0].message.role}")
search_logger.error("Check: 1) API key validity, 2) rate limits, 3) model availability")
# Implement fallback for Grok failures
search_logger.warning("Attempting fallback to OpenAI due to Grok failure...")
try:
# Use OpenAI as fallback with similar parameters
fallback_params = {
"model": "gpt-4o-mini",
"messages": request_params["messages"],
"temperature": request_params.get("temperature", 0.1),
"max_tokens": request_params.get("max_tokens", 500),
}
async with get_llm_client(provider="openai") as fallback_client:
fallback_response = await fallback_client.chat.completions.create(**fallback_params)
fallback_content = fallback_response.choices[0].message.content
if fallback_content and fallback_content.strip():
search_logger.info("gpt-4o-mini fallback succeeded")
response_content = fallback_content.strip()
else:
search_logger.error("gpt-4o-mini fallback also returned empty response")
raise ValueError(f"Both {model_choice} and gpt-4o-mini fallback failed")
except Exception as fallback_error:
search_logger.error(f"gpt-4o-mini fallback failed: {fallback_error}")
raise ValueError(f"{model_choice} failed and fallback to gpt-4o-mini also failed: {fallback_error}") from fallback_error
else:
search_logger.debug(f"Full response object: {response}")
raise ValueError("Empty response from LLM")
if not response_content:
# This should not happen after fallback logic, but safety check
raise ValueError("No valid response content after all attempts")
response_content = response_content.strip()
search_logger.debug(f"LLM API response: {repr(response_content[:200])}...")
result = json.loads(response_content)
payload = _extract_json_payload(response_content, code, language)
if payload != response_content:
search_logger.debug(
f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..."
)
result = json.loads(payload)
# Validate the response has the required fields
if not result.get("example_name") or not result.get("summary"):
@ -595,12 +929,38 @@ Format your response as JSON:
search_logger.error(
f"Failed to parse JSON response from LLM: {e}, Response: {repr(response_content) if 'response_content' in locals() else 'No response'}"
)
# Try to generate context-aware fallback
try:
fallback_json = synthesize_json_from_reasoning("", code, language)
if fallback_json:
fallback_result = json.loads(fallback_json)
search_logger.info(f"Generated context-aware fallback summary")
return {
"example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
"summary": fallback_result.get("summary", "Code example for demonstration purposes."),
}
except Exception:
pass # Fall through to generic fallback
return {
"example_name": f"Code Example{f' ({language})' if language else ''}",
"summary": "Code example for demonstration purposes.",
}
except Exception as e:
search_logger.error(f"Error generating code summary using unified LLM provider: {e}")
# Try to generate context-aware fallback
try:
fallback_json = synthesize_json_from_reasoning("", code, language)
if fallback_json:
fallback_result = json.loads(fallback_json)
search_logger.info(f"Generated context-aware fallback summary after error")
return {
"example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
"summary": fallback_result.get("summary", "Code example for demonstration purposes."),
}
except Exception:
pass # Fall through to generic fallback
return {
"example_name": f"Code Example{f' ({language})' if language else ''}",
"summary": "Code example for demonstration purposes.",
@ -608,7 +968,7 @@ Format your response as JSON:
async def generate_code_summaries_batch(
code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None
code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None, provider: str = None
) -> list[dict[str, str]]:
"""
Generate summaries for multiple code blocks with rate limiting and proper worker management.
@ -617,6 +977,7 @@ async def generate_code_summaries_batch(
code_blocks: List of code block dictionaries
max_workers: Maximum number of concurrent API requests
progress_callback: Optional callback for progress updates (async function)
provider: LLM provider to use for generation (e.g., 'grok', 'openai', 'anthropic')
Returns:
List of summary dictionaries
@ -627,8 +988,6 @@ async def generate_code_summaries_batch(
# Get max_workers from settings if not provided
if max_workers is None:
try:
from ...services.credential_service import credential_service
if (
credential_service._cache_initialized
and "CODE_SUMMARY_MAX_WORKERS" in credential_service._cache
@ -663,6 +1022,7 @@ async def generate_code_summaries_batch(
block["context_before"],
block["context_after"],
block.get("language", ""),
provider,
)
# Update progress
@ -757,29 +1117,17 @@ async def add_code_examples_to_supabase(
except Exception as e:
search_logger.error(f"Error deleting existing code examples for {url}: {e}")
# Check if contextual embeddings are enabled
# Check if contextual embeddings are enabled (use proper async method like document storage)
try:
from ..credential_service import credential_service
use_contextual_embeddings = credential_service._cache.get("USE_CONTEXTUAL_EMBEDDINGS")
if isinstance(use_contextual_embeddings, str):
use_contextual_embeddings = use_contextual_embeddings.lower() == "true"
elif isinstance(use_contextual_embeddings, dict) and use_contextual_embeddings.get(
"is_encrypted"
):
# Handle encrypted value
encrypted_value = use_contextual_embeddings.get("encrypted_value")
if encrypted_value:
try:
decrypted = credential_service._decrypt_value(encrypted_value)
use_contextual_embeddings = decrypted.lower() == "true"
except:
use_contextual_embeddings = False
else:
use_contextual_embeddings = False
raw_value = await credential_service.get_credential(
"USE_CONTEXTUAL_EMBEDDINGS", "false", decrypt=True
)
if isinstance(raw_value, str):
use_contextual_embeddings = raw_value.lower() == "true"
else:
use_contextual_embeddings = bool(use_contextual_embeddings)
except:
use_contextual_embeddings = bool(raw_value)
except Exception as e:
search_logger.error(f"DEBUG: Error reading contextual embeddings: {e}")
# Fallback to environment variable
use_contextual_embeddings = (
os.getenv("USE_CONTEXTUAL_EMBEDDINGS", "false").lower() == "true"
@ -848,14 +1196,13 @@ async def add_code_examples_to_supabase(
# Use only successful embeddings
valid_embeddings = result.embeddings
successful_texts = result.texts_processed
# Get model information for tracking
from ..llm_provider_service import get_embedding_model
from ..credential_service import credential_service
# Get embedding model name
embedding_model_name = await get_embedding_model(provider=provider)
# Get LLM chat model (used for code summaries and contextual embeddings if enabled)
llm_chat_model = None
try:
@ -868,7 +1215,7 @@ async def add_code_examples_to_supabase(
llm_chat_model = await credential_service.get_credential("MODEL_CHOICE", "gpt-4o-mini")
else:
# For code summaries, we use MODEL_CHOICE
llm_chat_model = _get_model_choice()
llm_chat_model = await _get_model_choice()
except Exception as e:
search_logger.warning(f"Failed to get LLM chat model: {e}")
llm_chat_model = "gpt-4o-mini" # Default fallback
@ -888,7 +1235,7 @@ async def add_code_examples_to_supabase(
positions_by_text[text].append(original_indices[k])
# Map successful texts back to their original indices
for embedding, text in zip(valid_embeddings, successful_texts, strict=False):
for embedding, text in zip(valid_embeddings, successful_texts, strict=True):
# Get the next available index for this text (handles duplicates)
if positions_by_text[text]:
orig_idx = positions_by_text[text].popleft() # Original j index in [i, batch_end)
@ -908,7 +1255,7 @@ async def add_code_examples_to_supabase(
# Determine the correct embedding column based on dimension
embedding_dim = len(embedding) if isinstance(embedding, list) else len(embedding.tolist())
embedding_column = None
if embedding_dim == 768:
embedding_column = "embedding_768"
elif embedding_dim == 1024:
@ -918,10 +1265,12 @@ async def add_code_examples_to_supabase(
elif embedding_dim == 3072:
embedding_column = "embedding_3072"
else:
# Default to closest supported dimension
search_logger.warning(f"Unsupported embedding dimension {embedding_dim}, using embedding_1536")
embedding_column = "embedding_1536"
# Skip unsupported dimensions to avoid corrupting the schema
search_logger.error(
f"Unsupported embedding dimension {embedding_dim}; skipping record to prevent column mismatch"
)
continue
batch_data.append({
"url": urls[idx],
"chunk_number": chunk_numbers[idx],
@ -954,9 +1303,7 @@ async def add_code_examples_to_supabase(
f"Error inserting batch into Supabase (attempt {retry + 1}/{max_retries}): {e}"
)
search_logger.info(f"Retrying in {retry_delay} seconds...")
import time
time.sleep(retry_delay)
await asyncio.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
else:
# Final attempt failed
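
The Supabase insert retry above follows a standard async exponential-backoff shape; a compact sketch of the same pattern (illustrative, not part of the patch):

import asyncio

async def with_backoff(operation, max_retries: int = 3, retry_delay: float = 1.0):
    """Retry an async operation, doubling the delay between attempts."""
    for attempt in range(max_retries):
        try:
            return await operation()
        except Exception:
            if attempt == max_retries - 1:
                raise  # final attempt failed
            await asyncio.sleep(retry_delay)
            retry_delay *= 2
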


@ -33,6 +33,12 @@ class AsyncContextManager:
class TestAsyncLLMProviderService:
"""Test suite for async LLM provider service functions"""
@staticmethod
def _make_mock_client():
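"""Build a MagicMock client whose aclose() is an AsyncMock, so tests can assert it was awaited on close."""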
client = MagicMock()
client.aclose = AsyncMock()
return client
@pytest.fixture(autouse=True)
def clear_cache(self):
"""Clear cache before each test"""
@ -98,7 +104,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
async with get_llm_client() as client:
@ -121,7 +127,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
async with get_llm_client() as client:
@ -143,7 +149,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
async with get_llm_client() as client:
@ -166,7 +172,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
async with get_llm_client(provider="openai") as client:
@ -194,7 +200,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
async with get_llm_client(use_embedding_provider=True) as client:
@ -225,7 +231,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
# Should fallback to Ollama instead of raising an error
@ -426,7 +432,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
# First call should hit the credential service
@ -464,7 +470,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
client_ref = None
@ -474,6 +480,7 @@ class TestAsyncLLMProviderService:
# After context manager exits, should still have reference to client
assert client_ref == mock_client
mock_client.aclose.assert_awaited_once()
@pytest.mark.asyncio
async def test_multiple_providers_in_sequence(self, mock_credential_service):
@ -494,7 +501,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
for config in configs:


@ -104,13 +104,15 @@ class TestCodeExtractionSourceId:
)
# Verify the correct source_id was passed (now with cancellation_check parameter)
mock_extract.assert_called_once_with(
crawl_results,
url_to_full_document,
source_id, # This should be the third argument
None,
None # cancellation_check parameter
)
mock_extract.assert_called_once()
args, kwargs = mock_extract.call_args
assert args[0] == crawl_results
assert args[1] == url_to_full_document
assert args[2] == source_id
assert args[3] is None
assert args[4] is None
if len(args) > 5:
assert args[5] is None
assert result == 5
@pytest.mark.asyncio
@ -174,4 +176,4 @@ class TestCodeExtractionSourceId:
import inspect
source = inspect.getsource(module)
assert "from urllib.parse import urlparse" not in source, \
"Should not import urlparse since we don't extract domain from URL anymore"
"Should not import urlparse since we don't extract domain from URL anymore"