Feat:Openrouter/Anthropic/grok-support (#231)
* Add Anthropic and Grok provider support * feat: Add crucial GPT-5 and reasoning model support for OpenRouter - Add requires_max_completion_tokens() function for GPT-5, o1, o3, Grok-3 series - Add prepare_chat_completion_params() for reasoning model compatibility - Implement max_tokens → max_completion_tokens conversion for reasoning models - Add temperature handling for reasoning models (must be 1.0 default) - Enhanced provider validation and API key security in provider endpoints - Streamlined retry logic (3→2 attempts) for faster issue detection - Add failure tracking and circuit breaker analysis for debugging - Support OpenRouter format detection (openai/gpt-5-nano, openai/o1-mini) - Improved Grok provider empty response handling with structured fallbacks - Enhanced contextual embedding with provider-aware model selection Core provider functionality: - OpenRouter, Grok, Anthropic provider support with full embedding integration - Provider-specific model defaults and validation - Secure API connectivity testing endpoints - Provider context passing for code generation workflows 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * fully working model providers, addressing securtiy and code related concerns, throughly hardening our code * added multiprovider support, embeddings model support, cleaned the pr, need to fix health check, asnyico tasks errors, and contextual embeddings error * fixed contextual embeddings issue * - Added inspect-aware shutdown handling so get_llm_client always closes the underlying AsyncOpenAI / httpx.AsyncClient while the loop is still alive, with defensive logging if shutdown happens late (python/src/server/services/llm_provider_service.py:14, python/src/server/ services/llm_provider_service.py:520). * - Restructured get_llm_client so client creation and usage live in separate try/finally blocks; fallback clients now close without logging spurious Error creating LLM client when downstream code raises (python/src/server/services/llm_provider_service.py:335-556). - Close logic now sanitizes provider names consistently and awaits whichever aclose/close coroutine the SDK exposes, keeping the loop shut down cleanly (python/src/server/services/llm_provider_service.py:530-559). Robust JSON Parsing - Added _extract_json_payload to strip code fences / extra text returned by Ollama before json.loads runs, averting the markdown-induced decode errors you saw in logs (python/src/server/services/storage/code_storage_service.py:40-63). - Swapped the direct parse call for the sanitized payload and emit a debug preview when cleanup alters the content (python/src/server/ services/storage/code_storage_service.py:858-864). * added provider connection support * added provider api key not being configured warning * Updated get_llm_client so missing OpenAI keys automatically fall back to Ollama (matching existing tests) and so unsupported providers still raise the legacy ValueError the suite expects. The fallback now reuses _get_optimal_ollama_instance and rethrows ValueError(OpenAI API key not found and Ollama fallback failed) when it cant connect. Adjusted test_code_extraction_source_id.py to accept the new optional argument on the mocked extractor (and confirm its None when present). * Resolved a few needed code rabbit suggestion - Updated the knowledge API key validation to call create_embedding with the provider argument and removed the hard-coded OpenAI fallback (python/src/server/api_routes/knowledge_api.py). - Broadened embedding provider detection so prefixed OpenRouter/OpenAI model names route through the correct client (python/src/server/ services/embeddings/embedding_service.py, python/src/server/services/llm_provider_service.py). - Removed the duplicate helper definitions from llm_provider_service.py, eliminating the stray docstring that was causing the import-time syntax error. * updated via code rabbit PR review, code rabbit in my IDE found no issues and no nitpicks with the updates! what was done: Credential service now persists the provider under the uppercase key LLM_PROVIDER, matching the read path (no new EMBEDDING_PROVIDER usage introduced). Embedding batch creation stops inserting blank strings, logging failures and skipping invalid items before they ever hit the provider (python/src/server/services/embeddings/embedding_service.py). Contextual embedding prompts use real newline characters everywhereboth when constructing the batch prompt and when parsing the models response (python/src/server/services/embeddings/contextual_embedding_service.py). Embedding provider routing already recognizes OpenRouter-prefixed OpenAI models via is_openai_embedding_model; no further change needed there. Embedding insertion now skips unsupported vector dimensions instead of forcing them into the 1536-column, and the backoff loop uses await asyncio.sleep so we no longer block the event loop (python/src/server/services/storage/code_storage_service.py). RAG settings props were extended to include LLM_INSTANCE_NAME and OLLAMA_EMBEDDING_INSTANCE_NAME, and the debug log no longer prints API-key prefixes (the rest of the TanStack refactor/EMBEDDING_PROVIDER support remains deferred). * test fix * enhanced Openrouters parsing logic to automatically detect reasoning models and parse regardless of json output or not. this commit creates a robust way for archons parsing to work throughly with openrouter automatically, regardless of the model youre using, to ensure proper functionality with out breaking any generation capabilities! --------- Co-authored-by: Chillbruhhh <joshchesser97@gmail.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
4c910c1471
commit
394ac1befa
@ -9,6 +9,103 @@ import { credentialsService } from '../../services/credentialsService';
|
||||
import OllamaModelDiscoveryModal from './OllamaModelDiscoveryModal';
|
||||
import OllamaModelSelectionModal from './OllamaModelSelectionModal';
|
||||
|
||||
type ProviderKey = 'openai' | 'google' | 'ollama' | 'anthropic' | 'grok' | 'openrouter';
|
||||
|
||||
interface ProviderModels {
|
||||
chatModel: string;
|
||||
embeddingModel: string;
|
||||
}
|
||||
|
||||
type ProviderModelMap = Record<ProviderKey, ProviderModels>;
|
||||
|
||||
// Provider model persistence helpers
|
||||
const PROVIDER_MODELS_KEY = 'archon_provider_models';
|
||||
|
||||
const getDefaultModels = (provider: ProviderKey): ProviderModels => {
|
||||
const chatDefaults: Record<ProviderKey, string> = {
|
||||
openai: 'gpt-4o-mini',
|
||||
anthropic: 'claude-3-5-sonnet-20241022',
|
||||
google: 'gemini-1.5-flash',
|
||||
grok: 'grok-3-mini', // Updated to use grok-3-mini as default
|
||||
openrouter: 'openai/gpt-4o-mini',
|
||||
ollama: 'llama3:8b'
|
||||
};
|
||||
|
||||
const embeddingDefaults: Record<ProviderKey, string> = {
|
||||
openai: 'text-embedding-3-small',
|
||||
anthropic: 'text-embedding-3-small', // Fallback to OpenAI
|
||||
google: 'text-embedding-004',
|
||||
grok: 'text-embedding-3-small', // Fallback to OpenAI
|
||||
openrouter: 'text-embedding-3-small',
|
||||
ollama: 'nomic-embed-text'
|
||||
};
|
||||
|
||||
return {
|
||||
chatModel: chatDefaults[provider],
|
||||
embeddingModel: embeddingDefaults[provider]
|
||||
};
|
||||
};
|
||||
|
||||
const saveProviderModels = (providerModels: ProviderModelMap): void => {
|
||||
try {
|
||||
localStorage.setItem(PROVIDER_MODELS_KEY, JSON.stringify(providerModels));
|
||||
} catch (error) {
|
||||
console.error('Failed to save provider models:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const loadProviderModels = (): ProviderModelMap => {
|
||||
try {
|
||||
const saved = localStorage.getItem(PROVIDER_MODELS_KEY);
|
||||
if (saved) {
|
||||
return JSON.parse(saved);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load provider models:', error);
|
||||
}
|
||||
|
||||
// Return defaults for all providers if nothing saved
|
||||
const providers: ProviderKey[] = ['openai', 'google', 'openrouter', 'ollama', 'anthropic', 'grok'];
|
||||
const defaultModels: ProviderModelMap = {} as ProviderModelMap;
|
||||
|
||||
providers.forEach(provider => {
|
||||
defaultModels[provider] = getDefaultModels(provider);
|
||||
});
|
||||
|
||||
return defaultModels;
|
||||
};
|
||||
|
||||
// Static color styles mapping (prevents Tailwind JIT purging)
|
||||
const colorStyles: Record<ProviderKey, string> = {
|
||||
openai: 'border-green-500 bg-green-500/10',
|
||||
google: 'border-blue-500 bg-blue-500/10',
|
||||
openrouter: 'border-cyan-500 bg-cyan-500/10',
|
||||
ollama: 'border-purple-500 bg-purple-500/10',
|
||||
anthropic: 'border-orange-500 bg-orange-500/10',
|
||||
grok: 'border-yellow-500 bg-yellow-500/10',
|
||||
};
|
||||
|
||||
const providerAlertStyles: Record<ProviderKey, string> = {
|
||||
openai: 'bg-green-50 dark:bg-green-900/20 border-green-200 dark:border-green-800 text-green-800 dark:text-green-300',
|
||||
google: 'bg-blue-50 dark:bg-blue-900/20 border-blue-200 dark:border-blue-800 text-blue-800 dark:text-blue-300',
|
||||
openrouter: 'bg-cyan-50 dark:bg-cyan-900/20 border-cyan-200 dark:border-cyan-800 text-cyan-800 dark:text-cyan-300',
|
||||
ollama: 'bg-purple-50 dark:bg-purple-900/20 border-purple-200 dark:border-purple-800 text-purple-800 dark:text-purple-300',
|
||||
anthropic: 'bg-orange-50 dark:bg-orange-900/20 border-orange-200 dark:border-orange-800 text-orange-800 dark:text-orange-300',
|
||||
grok: 'bg-yellow-50 dark:bg-yellow-900/20 border-yellow-200 dark:border-yellow-800 text-yellow-800 dark:text-yellow-300',
|
||||
};
|
||||
|
||||
const providerAlertMessages: Record<ProviderKey, string> = {
|
||||
openai: 'Configure your OpenAI API key in the credentials section to use GPT models.',
|
||||
google: 'Configure your Google API key in the credentials section to use Gemini models.',
|
||||
openrouter: 'Configure your OpenRouter API key in the credentials section to use models.',
|
||||
ollama: 'Configure your Ollama instances in this panel to connect local models.',
|
||||
anthropic: 'Configure your Anthropic API key in the credentials section to use Claude models.',
|
||||
grok: 'Configure your Grok API key in the credentials section to use Grok models.',
|
||||
};
|
||||
|
||||
const isProviderKey = (value: unknown): value is ProviderKey =>
|
||||
typeof value === 'string' && ['openai', 'google', 'openrouter', 'ollama', 'anthropic', 'grok'].includes(value);
|
||||
|
||||
interface RAGSettingsProps {
|
||||
ragSettings: {
|
||||
MODEL_CHOICE: string;
|
||||
@ -19,8 +116,10 @@ interface RAGSettingsProps {
|
||||
USE_RERANKING: boolean;
|
||||
LLM_PROVIDER?: string;
|
||||
LLM_BASE_URL?: string;
|
||||
LLM_INSTANCE_NAME?: string;
|
||||
EMBEDDING_MODEL?: string;
|
||||
OLLAMA_EMBEDDING_URL?: string;
|
||||
OLLAMA_EMBEDDING_INSTANCE_NAME?: string;
|
||||
// Crawling Performance Settings
|
||||
CRAWL_BATCH_SIZE?: number;
|
||||
CRAWL_MAX_CONCURRENT?: number;
|
||||
@ -57,7 +156,10 @@ export const RAGSettings = ({
|
||||
// Model selection modals state
|
||||
const [showLLMModelSelectionModal, setShowLLMModelSelectionModal] = useState(false);
|
||||
const [showEmbeddingModelSelectionModal, setShowEmbeddingModelSelectionModal] = useState(false);
|
||||
|
||||
|
||||
// Provider-specific model persistence state
|
||||
const [providerModels, setProviderModels] = useState<ProviderModelMap>(() => loadProviderModels());
|
||||
|
||||
// Instance configurations
|
||||
const [llmInstanceConfig, setLLMInstanceConfig] = useState({
|
||||
name: '',
|
||||
@ -113,6 +215,25 @@ export const RAGSettings = ({
|
||||
}
|
||||
}, [ragSettings.OLLAMA_EMBEDDING_URL, ragSettings.OLLAMA_EMBEDDING_INSTANCE_NAME]);
|
||||
|
||||
// Provider model persistence effects
|
||||
useEffect(() => {
|
||||
// Update provider models when current models change
|
||||
const currentProvider = ragSettings.LLM_PROVIDER as ProviderKey;
|
||||
if (currentProvider && ragSettings.MODEL_CHOICE && ragSettings.EMBEDDING_MODEL) {
|
||||
setProviderModels(prev => {
|
||||
const updated = {
|
||||
...prev,
|
||||
[currentProvider]: {
|
||||
chatModel: ragSettings.MODEL_CHOICE,
|
||||
embeddingModel: ragSettings.EMBEDDING_MODEL
|
||||
}
|
||||
};
|
||||
saveProviderModels(updated);
|
||||
return updated;
|
||||
});
|
||||
}
|
||||
}, [ragSettings.MODEL_CHOICE, ragSettings.EMBEDDING_MODEL, ragSettings.LLM_PROVIDER]);
|
||||
|
||||
// Load API credentials for status checking
|
||||
useEffect(() => {
|
||||
const loadApiCredentials = async () => {
|
||||
@ -197,58 +318,27 @@ export const RAGSettings = ({
|
||||
}>({});
|
||||
|
||||
// Test connection to external providers
|
||||
const testProviderConnection = async (provider: string, apiKey: string): Promise<boolean> => {
|
||||
const testProviderConnection = async (provider: string): Promise<boolean> => {
|
||||
setProviderConnectionStatus(prev => ({
|
||||
...prev,
|
||||
[provider]: { ...prev[provider], checking: true }
|
||||
}));
|
||||
|
||||
try {
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
// Test OpenAI connection with a simple completion request
|
||||
const openaiResponse = await fetch('https://api.openai.com/v1/models', {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
});
|
||||
|
||||
if (openaiResponse.ok) {
|
||||
setProviderConnectionStatus(prev => ({
|
||||
...prev,
|
||||
openai: { connected: true, checking: false, lastChecked: new Date() }
|
||||
}));
|
||||
return true;
|
||||
} else {
|
||||
throw new Error(`OpenAI API returned ${openaiResponse.status}`);
|
||||
}
|
||||
// Use server-side API endpoint for secure connectivity testing
|
||||
const response = await fetch(`/api/providers/${provider}/status`);
|
||||
const result = await response.json();
|
||||
|
||||
case 'google':
|
||||
// Test Google Gemini connection
|
||||
const googleResponse = await fetch(`https://generativelanguage.googleapis.com/v1/models?key=${apiKey}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
});
|
||||
|
||||
if (googleResponse.ok) {
|
||||
setProviderConnectionStatus(prev => ({
|
||||
...prev,
|
||||
google: { connected: true, checking: false, lastChecked: new Date() }
|
||||
}));
|
||||
return true;
|
||||
} else {
|
||||
throw new Error(`Google API returned ${googleResponse.status}`);
|
||||
}
|
||||
const isConnected = result.ok && result.reason === 'connected';
|
||||
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
setProviderConnectionStatus(prev => ({
|
||||
...prev,
|
||||
[provider]: { connected: isConnected, checking: false, lastChecked: new Date() }
|
||||
}));
|
||||
|
||||
return isConnected;
|
||||
} catch (error) {
|
||||
console.error(`Failed to test ${provider} connection:`, error);
|
||||
console.error(`Error testing ${provider} connection:`, error);
|
||||
setProviderConnectionStatus(prev => ({
|
||||
...prev,
|
||||
[provider]: { connected: false, checking: false, lastChecked: new Date() }
|
||||
@ -260,37 +350,27 @@ export const RAGSettings = ({
|
||||
// Test provider connections when API credentials change
|
||||
useEffect(() => {
|
||||
const testConnections = async () => {
|
||||
const providers = ['openai', 'google'];
|
||||
|
||||
// Test all supported providers
|
||||
const providers = ['openai', 'google', 'anthropic', 'openrouter', 'grok'];
|
||||
|
||||
for (const provider of providers) {
|
||||
const keyName = provider === 'openai' ? 'OPENAI_API_KEY' : 'GOOGLE_API_KEY';
|
||||
const apiKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === keyName);
|
||||
const keyValue = apiKey ? apiCredentials[apiKey] : undefined;
|
||||
|
||||
if (keyValue && keyValue.trim().length > 0) {
|
||||
// Don't test if we've already checked recently (within last 30 seconds)
|
||||
const lastChecked = providerConnectionStatus[provider]?.lastChecked;
|
||||
const now = new Date();
|
||||
const timeSinceLastCheck = lastChecked ? now.getTime() - lastChecked.getTime() : Infinity;
|
||||
|
||||
if (timeSinceLastCheck > 30000) { // 30 seconds
|
||||
console.log(`🔄 Testing ${provider} connection...`);
|
||||
await testProviderConnection(provider, keyValue);
|
||||
}
|
||||
} else {
|
||||
// No API key, mark as disconnected
|
||||
setProviderConnectionStatus(prev => ({
|
||||
...prev,
|
||||
[provider]: { connected: false, checking: false, lastChecked: new Date() }
|
||||
}));
|
||||
// Don't test if we've already checked recently (within last 30 seconds)
|
||||
const lastChecked = providerConnectionStatus[provider]?.lastChecked;
|
||||
const now = new Date();
|
||||
const timeSinceLastCheck = lastChecked ? now.getTime() - lastChecked.getTime() : Infinity;
|
||||
|
||||
if (timeSinceLastCheck > 30000) { // 30 seconds
|
||||
console.log(`🔄 Testing ${provider} connection...`);
|
||||
await testProviderConnection(provider);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Only test if we have credentials loaded
|
||||
if (Object.keys(apiCredentials).length > 0) {
|
||||
testConnections();
|
||||
}
|
||||
// Test connections periodically (every 60 seconds)
|
||||
testConnections();
|
||||
const interval = setInterval(testConnections, 60000);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [apiCredentials]); // Test when credentials change
|
||||
|
||||
// Ref to track if initial test has been run (will be used after function definitions)
|
||||
@ -662,24 +742,41 @@ export const RAGSettings = ({
|
||||
if (llmStatus.online || embeddingStatus.online) return 'partial';
|
||||
return 'missing';
|
||||
case 'anthropic':
|
||||
// Check if Anthropic API key is configured (case insensitive)
|
||||
const anthropicKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'ANTHROPIC_API_KEY');
|
||||
const hasAnthropicKey = anthropicKey && apiCredentials[anthropicKey] && apiCredentials[anthropicKey].trim().length > 0;
|
||||
return hasAnthropicKey ? 'configured' : 'missing';
|
||||
// Use server-side connection status
|
||||
const anthropicConnected = providerConnectionStatus['anthropic']?.connected || false;
|
||||
const anthropicChecking = providerConnectionStatus['anthropic']?.checking || false;
|
||||
if (anthropicChecking) return 'partial';
|
||||
return anthropicConnected ? 'configured' : 'missing';
|
||||
case 'grok':
|
||||
// Check if Grok API key is configured (case insensitive)
|
||||
const grokKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'GROK_API_KEY');
|
||||
const hasGrokKey = grokKey && apiCredentials[grokKey] && apiCredentials[grokKey].trim().length > 0;
|
||||
return hasGrokKey ? 'configured' : 'missing';
|
||||
// Use server-side connection status
|
||||
const grokConnected = providerConnectionStatus['grok']?.connected || false;
|
||||
const grokChecking = providerConnectionStatus['grok']?.checking || false;
|
||||
if (grokChecking) return 'partial';
|
||||
return grokConnected ? 'configured' : 'missing';
|
||||
case 'openrouter':
|
||||
// Check if OpenRouter API key is configured (case insensitive)
|
||||
const openRouterKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'OPENROUTER_API_KEY');
|
||||
const hasOpenRouterKey = openRouterKey && apiCredentials[openRouterKey] && apiCredentials[openRouterKey].trim().length > 0;
|
||||
return hasOpenRouterKey ? 'configured' : 'missing';
|
||||
// Use server-side connection status
|
||||
const openRouterConnected = providerConnectionStatus['openrouter']?.connected || false;
|
||||
const openRouterChecking = providerConnectionStatus['openrouter']?.checking || false;
|
||||
if (openRouterChecking) return 'partial';
|
||||
return openRouterConnected ? 'configured' : 'missing';
|
||||
default:
|
||||
return 'missing';
|
||||
}
|
||||
};;
|
||||
};
|
||||
|
||||
const selectedProviderKey = isProviderKey(ragSettings.LLM_PROVIDER)
|
||||
? (ragSettings.LLM_PROVIDER as ProviderKey)
|
||||
: undefined;
|
||||
const selectedProviderStatus = selectedProviderKey ? getProviderStatus(selectedProviderKey) : undefined;
|
||||
const shouldShowProviderAlert = Boolean(
|
||||
selectedProviderKey && selectedProviderStatus === 'missing'
|
||||
);
|
||||
const providerAlertClassName = shouldShowProviderAlert && selectedProviderKey
|
||||
? providerAlertStyles[selectedProviderKey]
|
||||
: '';
|
||||
const providerAlertMessage = shouldShowProviderAlert && selectedProviderKey
|
||||
? providerAlertMessages[selectedProviderKey]
|
||||
: '';
|
||||
|
||||
// Test Ollama connectivity when Settings page loads (scenario 4: page load)
|
||||
// This useEffect is placed after function definitions to ensure access to manualTestConnection
|
||||
@ -750,55 +847,32 @@ export const RAGSettings = ({
|
||||
{[
|
||||
{ key: 'openai', name: 'OpenAI', logo: '/img/OpenAI.png', color: 'green' },
|
||||
{ key: 'google', name: 'Google', logo: '/img/google-logo.svg', color: 'blue' },
|
||||
{ key: 'openrouter', name: 'OpenRouter', logo: '/img/OpenRouter.png', color: 'cyan' },
|
||||
{ key: 'ollama', name: 'Ollama', logo: '/img/Ollama.png', color: 'purple' },
|
||||
{ key: 'anthropic', name: 'Anthropic', logo: '/img/claude-logo.svg', color: 'orange' },
|
||||
{ key: 'grok', name: 'Grok', logo: '/img/Grok.png', color: 'yellow' },
|
||||
{ key: 'openrouter', name: 'OpenRouter', logo: '/img/OpenRouter.png', color: 'cyan' }
|
||||
{ key: 'grok', name: 'Grok', logo: '/img/Grok.png', color: 'yellow' }
|
||||
].map(provider => (
|
||||
<button
|
||||
key={provider.key}
|
||||
type="button"
|
||||
onClick={() => {
|
||||
// Get saved models for this provider, or use defaults
|
||||
const providerKey = provider.key as ProviderKey;
|
||||
const savedModels = providerModels[providerKey] || getDefaultModels(providerKey);
|
||||
|
||||
const updatedSettings = {
|
||||
...ragSettings,
|
||||
LLM_PROVIDER: provider.key
|
||||
LLM_PROVIDER: providerKey,
|
||||
MODEL_CHOICE: savedModels.chatModel,
|
||||
EMBEDDING_MODEL: savedModels.embeddingModel
|
||||
};
|
||||
|
||||
// Set models to provider-appropriate defaults when switching providers
|
||||
// This ensures both LLM and embedding models switch when provider changes
|
||||
const getDefaultChatModel = (provider: string): string => {
|
||||
switch (provider) {
|
||||
case 'openai': return 'gpt-4o-mini';
|
||||
case 'anthropic': return 'claude-3-5-sonnet-20241022';
|
||||
case 'google': return 'gemini-1.5-flash';
|
||||
case 'grok': return 'grok-2-latest';
|
||||
case 'ollama': return '';
|
||||
case 'openrouter': return 'anthropic/claude-3.5-sonnet';
|
||||
default: return 'gpt-4o-mini';
|
||||
}
|
||||
};
|
||||
|
||||
const getDefaultEmbeddingModel = (provider: string): string => {
|
||||
switch (provider) {
|
||||
case 'openai': return 'text-embedding-3-small';
|
||||
case 'google': return 'text-embedding-004';
|
||||
case 'ollama': return '';
|
||||
case 'openrouter': return 'text-embedding-3-small';
|
||||
case 'anthropic':
|
||||
case 'grok':
|
||||
default: return 'text-embedding-3-small';
|
||||
}
|
||||
};
|
||||
|
||||
updatedSettings.MODEL_CHOICE = getDefaultChatModel(provider.key);
|
||||
updatedSettings.EMBEDDING_MODEL = getDefaultEmbeddingModel(provider.key);
|
||||
|
||||
|
||||
setRagSettings(updatedSettings);
|
||||
}}
|
||||
className={`
|
||||
relative p-3 rounded-lg border-2 transition-all duration-200 text-center
|
||||
${ragSettings.LLM_PROVIDER === provider.key
|
||||
? `border-${provider.color}-500 bg-${provider.color}-500/10 shadow-[0_0_15px_rgba(34,197,94,0.3)]`
|
||||
? `${colorStyles[provider.key as ProviderKey]} shadow-[0_0_15px_rgba(34,197,94,0.3)]`
|
||||
: 'border-gray-300 dark:border-gray-600 hover:border-gray-400 dark:hover:border-gray-500'
|
||||
}
|
||||
hover:scale-105 active:scale-95
|
||||
@ -813,8 +887,8 @@ export const RAGSettings = ({
|
||||
: ''
|
||||
}`}
|
||||
/>
|
||||
<div className={`text-sm font-medium text-gray-700 dark:text-gray-300 ${
|
||||
provider.key === 'openrouter' ? 'text-center' : ''
|
||||
<div className={`font-medium text-gray-700 dark:text-gray-300 text-center ${
|
||||
provider.key === 'openrouter' ? 'text-xs' : 'text-sm'
|
||||
}`}>
|
||||
{provider.name}
|
||||
</div>
|
||||
@ -842,13 +916,6 @@ export const RAGSettings = ({
|
||||
);
|
||||
}
|
||||
})()}
|
||||
{(provider.key === 'anthropic' || provider.key === 'grok' || provider.key === 'openrouter') && (
|
||||
<div className="absolute inset-0 bg-black/20 rounded-lg flex items-center justify-center">
|
||||
<div className="bg-yellow-500/80 text-black text-xs font-bold px-2 py-1 rounded transform -rotate-12">
|
||||
Coming Soon
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
@ -1223,19 +1290,9 @@ export const RAGSettings = ({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{ragSettings.LLM_PROVIDER === 'anthropic' && (
|
||||
<div className="p-4 bg-orange-50 dark:bg-orange-900/20 border border-orange-200 dark:border-orange-800 rounded-lg mb-4">
|
||||
<p className="text-sm text-orange-800 dark:text-orange-300">
|
||||
Configure your Anthropic API key in the credentials section to use Claude models.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{ragSettings.LLM_PROVIDER === 'groq' && (
|
||||
<div className="p-4 bg-yellow-50 dark:bg-yellow-900/20 border border-yellow-200 dark:border-yellow-800 rounded-lg mb-4">
|
||||
<p className="text-sm text-yellow-800 dark:text-yellow-300">
|
||||
Groq provides fast inference with Llama, Mixtral, and Gemma models.
|
||||
</p>
|
||||
{shouldShowProviderAlert && (
|
||||
<div className={`p-4 border rounded-lg mb-4 ${providerAlertClassName}`}>
|
||||
<p className="text-sm">{providerAlertMessage}</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@ -1853,94 +1910,56 @@ export const RAGSettings = ({
|
||||
function getDisplayedChatModel(ragSettings: any): string {
|
||||
const provider = ragSettings.LLM_PROVIDER || 'openai';
|
||||
const modelChoice = ragSettings.MODEL_CHOICE;
|
||||
|
||||
// Check if the stored model is appropriate for the current provider
|
||||
const isModelAppropriate = (model: string, provider: string): boolean => {
|
||||
if (!model) return false;
|
||||
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
return model.startsWith('gpt-') || model.startsWith('o1-') || model.includes('text-davinci') || model.includes('text-embedding');
|
||||
case 'anthropic':
|
||||
return model.startsWith('claude-');
|
||||
case 'google':
|
||||
return model.startsWith('gemini-') || model.startsWith('text-embedding-');
|
||||
case 'grok':
|
||||
return model.startsWith('grok-');
|
||||
case 'ollama':
|
||||
return !model.startsWith('gpt-') && !model.startsWith('claude-') && !model.startsWith('gemini-') && !model.startsWith('grok-');
|
||||
case 'openrouter':
|
||||
return model.includes('/') || model.startsWith('anthropic/') || model.startsWith('openai/');
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
// Use stored model if it's appropriate for the provider, otherwise use default
|
||||
const useStoredModel = modelChoice && isModelAppropriate(modelChoice, provider);
|
||||
|
||||
|
||||
// Always prioritize user input to allow editing
|
||||
if (modelChoice !== undefined && modelChoice !== null) {
|
||||
return modelChoice;
|
||||
}
|
||||
|
||||
// Only use defaults when there's no stored value
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
return useStoredModel ? modelChoice : 'gpt-4o-mini';
|
||||
return 'gpt-4o-mini';
|
||||
case 'anthropic':
|
||||
return useStoredModel ? modelChoice : 'claude-3-5-sonnet-20241022';
|
||||
return 'claude-3-5-sonnet-20241022';
|
||||
case 'google':
|
||||
return useStoredModel ? modelChoice : 'gemini-1.5-flash';
|
||||
return 'gemini-1.5-flash';
|
||||
case 'grok':
|
||||
return useStoredModel ? modelChoice : 'grok-2-latest';
|
||||
return 'grok-3-mini';
|
||||
case 'ollama':
|
||||
return useStoredModel ? modelChoice : '';
|
||||
return '';
|
||||
case 'openrouter':
|
||||
return useStoredModel ? modelChoice : 'anthropic/claude-3.5-sonnet';
|
||||
return 'anthropic/claude-3.5-sonnet';
|
||||
default:
|
||||
return useStoredModel ? modelChoice : 'gpt-4o-mini';
|
||||
return 'gpt-4o-mini';
|
||||
}
|
||||
}
|
||||
|
||||
function getDisplayedEmbeddingModel(ragSettings: any): string {
|
||||
const provider = ragSettings.LLM_PROVIDER || 'openai';
|
||||
const embeddingModel = ragSettings.EMBEDDING_MODEL;
|
||||
|
||||
// Check if the stored embedding model is appropriate for the current provider
|
||||
const isEmbeddingModelAppropriate = (model: string, provider: string): boolean => {
|
||||
if (!model) return false;
|
||||
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
return model.startsWith('text-embedding-') || model.includes('ada-');
|
||||
case 'anthropic':
|
||||
return false; // Claude doesn't provide embedding models
|
||||
case 'google':
|
||||
return model.startsWith('text-embedding-') || model.startsWith('textembedding-') || model.includes('embedding');
|
||||
case 'grok':
|
||||
return false; // Grok doesn't provide embedding models
|
||||
case 'ollama':
|
||||
return !model.startsWith('text-embedding-') || model.includes('embed') || model.includes('arctic');
|
||||
case 'openrouter':
|
||||
return model.startsWith('text-embedding-') || model.includes('/');
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
// Use stored model if it's appropriate for the provider, otherwise use default
|
||||
const useStoredModel = embeddingModel && isEmbeddingModelAppropriate(embeddingModel, provider);
|
||||
|
||||
|
||||
// Always prioritize user input to allow editing
|
||||
if (embeddingModel !== undefined && embeddingModel !== null && embeddingModel !== '') {
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
// Provide appropriate defaults based on LLM provider
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
return useStoredModel ? embeddingModel : 'text-embedding-3-small';
|
||||
case 'anthropic':
|
||||
return 'Not available - Claude does not provide embedding models';
|
||||
return 'text-embedding-3-small';
|
||||
case 'google':
|
||||
return useStoredModel ? embeddingModel : 'text-embedding-004';
|
||||
case 'grok':
|
||||
return 'Not available - Grok does not provide embedding models';
|
||||
return 'text-embedding-004';
|
||||
case 'ollama':
|
||||
return useStoredModel ? embeddingModel : '';
|
||||
return '';
|
||||
case 'openrouter':
|
||||
return useStoredModel ? embeddingModel : 'text-embedding-3-small';
|
||||
return 'text-embedding-3-small'; // Default to OpenAI embedding for OpenRouter
|
||||
case 'anthropic':
|
||||
return 'text-embedding-3-small'; // Use OpenAI embeddings with Claude
|
||||
case 'grok':
|
||||
return 'text-embedding-3-small'; // Use OpenAI embeddings with Grok
|
||||
default:
|
||||
return useStoredModel ? embeddingModel : 'text-embedding-3-small';
|
||||
return 'text-embedding-3-small';
|
||||
}
|
||||
}
|
||||
|
||||
@ -2035,4 +2054,4 @@ const CustomCheckbox = ({
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
};
|
||||
|
||||
@ -14,6 +14,7 @@ from .internal_api import router as internal_router
|
||||
from .knowledge_api import router as knowledge_router
|
||||
from .mcp_api import router as mcp_router
|
||||
from .projects_api import router as projects_router
|
||||
from .providers_api import router as providers_router
|
||||
from .settings_api import router as settings_router
|
||||
|
||||
__all__ = [
|
||||
@ -23,4 +24,5 @@ __all__ = [
|
||||
"projects_router",
|
||||
"agent_chat_router",
|
||||
"internal_router",
|
||||
"providers_router",
|
||||
]
|
||||
|
||||
@ -18,6 +18,8 @@ from urllib.parse import urlparse
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Basic validation - simplified inline version
|
||||
|
||||
# Import unified logging
|
||||
from ..config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info
|
||||
from ..services.crawler_manager import get_crawler
|
||||
@ -62,26 +64,59 @@ async def _validate_provider_api_key(provider: str = None) -> None:
|
||||
logger.info("🔑 Starting API key validation...")
|
||||
|
||||
try:
|
||||
# Basic provider validation
|
||||
if not provider:
|
||||
provider = "openai"
|
||||
else:
|
||||
# Simple provider validation
|
||||
allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"}
|
||||
if provider not in allowed_providers:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid provider name",
|
||||
"message": f"Provider '{provider}' not supported",
|
||||
"error_type": "validation_error"
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"🔑 Testing {provider.title()} API key with minimal embedding request...")
|
||||
|
||||
# Test API key with minimal embedding request - this will fail if key is invalid
|
||||
from ..services.embeddings.embedding_service import create_embedding
|
||||
test_result = await create_embedding(text="test")
|
||||
|
||||
if not test_result:
|
||||
logger.error(f"❌ {provider.title()} API key validation failed - no embedding returned")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": f"Invalid {provider.title()} API key",
|
||||
"message": f"Please verify your {provider.title()} API key in Settings.",
|
||||
"error_type": "authentication_failed",
|
||||
"provider": provider
|
||||
}
|
||||
)
|
||||
# Basic sanitization for logging
|
||||
safe_provider = provider[:20] # Limit length
|
||||
logger.info(f"🔑 Testing {safe_provider.title()} API key with minimal embedding request...")
|
||||
|
||||
try:
|
||||
# Test API key with minimal embedding request using provider-scoped configuration
|
||||
from ..services.embeddings.embedding_service import create_embedding
|
||||
|
||||
test_result = await create_embedding(text="test", provider=provider)
|
||||
|
||||
if not test_result:
|
||||
logger.error(
|
||||
f"❌ {provider.title()} API key validation failed - no embedding returned"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": f"Invalid {provider.title()} API key",
|
||||
"message": f"Please verify your {provider.title()} API key in Settings.",
|
||||
"error_type": "authentication_failed",
|
||||
"provider": provider,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"❌ {provider.title()} API key validation failed: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": f"Invalid {provider.title()} API key",
|
||||
"message": f"Please verify your {provider.title()} API key in Settings. Error: {str(e)[:100]}",
|
||||
"error_type": "authentication_failed",
|
||||
"provider": provider,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"✅ {provider.title()} API key validation successful")
|
||||
|
||||
|
||||
154
python/src/server/api_routes/providers_api.py
Normal file
154
python/src/server/api_routes/providers_api.py
Normal file
@ -0,0 +1,154 @@
|
||||
"""
|
||||
Provider status API endpoints for testing connectivity
|
||||
|
||||
Handles server-side provider connectivity testing without exposing API keys to frontend.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Path
|
||||
|
||||
from ..config.logfire_config import logfire
|
||||
from ..services.credential_service import credential_service
|
||||
# Provider validation - simplified inline version
|
||||
|
||||
router = APIRouter(prefix="/api/providers", tags=["providers"])
|
||||
|
||||
|
||||
async def test_openai_connection(api_key: str) -> bool:
|
||||
"""Test OpenAI API connectivity"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
"https://api.openai.com/v1/models",
|
||||
headers={"Authorization": f"Bearer {api_key}"}
|
||||
)
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
logfire.warning(f"OpenAI connectivity test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_google_connection(api_key: str) -> bool:
|
||||
"""Test Google AI API connectivity"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
"https://generativelanguage.googleapis.com/v1/models",
|
||||
headers={"x-goog-api-key": api_key}
|
||||
)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
logfire.warning("Google AI connectivity test failed")
|
||||
return False
|
||||
|
||||
|
||||
async def test_anthropic_connection(api_key: str) -> bool:
|
||||
"""Test Anthropic API connectivity"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
"https://api.anthropic.com/v1/models",
|
||||
headers={
|
||||
"x-api-key": api_key,
|
||||
"anthropic-version": "2023-06-01"
|
||||
}
|
||||
)
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
logfire.warning(f"Anthropic connectivity test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_openrouter_connection(api_key: str) -> bool:
|
||||
"""Test OpenRouter API connectivity"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
"https://openrouter.ai/api/v1/models",
|
||||
headers={"Authorization": f"Bearer {api_key}"}
|
||||
)
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
logfire.warning(f"OpenRouter connectivity test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_grok_connection(api_key: str) -> bool:
|
||||
"""Test Grok API connectivity"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
"https://api.x.ai/v1/models",
|
||||
headers={"Authorization": f"Bearer {api_key}"}
|
||||
)
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
logfire.warning(f"Grok connectivity test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
PROVIDER_TESTERS = {
|
||||
"openai": test_openai_connection,
|
||||
"google": test_google_connection,
|
||||
"anthropic": test_anthropic_connection,
|
||||
"openrouter": test_openrouter_connection,
|
||||
"grok": test_grok_connection,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{provider}/status")
|
||||
async def get_provider_status(
|
||||
provider: str = Path(
|
||||
...,
|
||||
description="Provider name to test connectivity for",
|
||||
regex="^[a-z0-9_]+$",
|
||||
max_length=20
|
||||
)
|
||||
):
|
||||
"""Test provider connectivity using server-side API key (secure)"""
|
||||
try:
|
||||
# Basic provider validation
|
||||
allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"}
|
||||
if provider not in allowed_providers:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid provider '{provider}'. Allowed providers: {sorted(allowed_providers)}"
|
||||
)
|
||||
|
||||
# Basic sanitization for logging
|
||||
safe_provider = provider[:20] # Limit length
|
||||
logfire.info(f"Testing {safe_provider} connectivity server-side")
|
||||
|
||||
if provider not in PROVIDER_TESTERS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provider '{provider}' not supported for connectivity testing"
|
||||
)
|
||||
|
||||
# Get API key server-side (never expose to client)
|
||||
key_name = f"{provider.upper()}_API_KEY"
|
||||
api_key = await credential_service.get_credential(key_name, decrypt=True)
|
||||
|
||||
if not api_key or not isinstance(api_key, str) or not api_key.strip():
|
||||
logfire.info(f"No API key configured for {safe_provider}")
|
||||
return {"ok": False, "reason": "no_key"}
|
||||
|
||||
# Test connectivity using server-side key
|
||||
tester = PROVIDER_TESTERS[provider]
|
||||
is_connected = await tester(api_key)
|
||||
|
||||
logfire.info(f"{safe_provider} connectivity test result: {is_connected}")
|
||||
return {
|
||||
"ok": is_connected,
|
||||
"reason": "connected" if is_connected else "connection_failed",
|
||||
"provider": provider # Echo back validated provider name
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions (they're already properly formatted)
|
||||
raise
|
||||
except Exception as e:
|
||||
# Basic error sanitization for logging
|
||||
safe_error = str(e)[:100] # Limit length
|
||||
logfire.error(f"Error testing {provider[:20]} connectivity: {safe_error}")
|
||||
raise HTTPException(status_code=500, detail={"error": "Internal server error during connectivity test"})
|
||||
@ -26,6 +26,7 @@ from .api_routes.mcp_api import router as mcp_router
|
||||
from .api_routes.ollama_api import router as ollama_router
|
||||
from .api_routes.progress_api import router as progress_router
|
||||
from .api_routes.projects_api import router as projects_router
|
||||
from .api_routes.providers_api import router as providers_router
|
||||
|
||||
# Import modular API routers
|
||||
from .api_routes.settings_api import router as settings_router
|
||||
@ -186,6 +187,7 @@ app.include_router(progress_router)
|
||||
app.include_router(agent_chat_router)
|
||||
app.include_router(internal_router)
|
||||
app.include_router(bug_report_router)
|
||||
app.include_router(providers_router)
|
||||
|
||||
|
||||
# Root endpoint
|
||||
|
||||
@ -139,6 +139,7 @@ class CodeExtractionService:
|
||||
source_id: str,
|
||||
progress_callback: Callable | None = None,
|
||||
cancellation_check: Callable[[], None] | None = None,
|
||||
provider: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Extract code examples from crawled documents and store them.
|
||||
@ -204,7 +205,7 @@ class CodeExtractionService:
|
||||
|
||||
# Generate summaries for code blocks
|
||||
summary_results = await self._generate_code_summaries(
|
||||
all_code_blocks, summary_callback, cancellation_check
|
||||
all_code_blocks, summary_callback, cancellation_check, provider
|
||||
)
|
||||
|
||||
# Prepare code examples for storage
|
||||
@ -223,7 +224,7 @@ class CodeExtractionService:
|
||||
|
||||
# Store code examples in database
|
||||
return await self._store_code_examples(
|
||||
storage_data, url_to_full_document, storage_callback
|
||||
storage_data, url_to_full_document, storage_callback, provider
|
||||
)
|
||||
|
||||
async def _extract_code_blocks_from_documents(
|
||||
@ -1523,6 +1524,7 @@ class CodeExtractionService:
|
||||
all_code_blocks: list[dict[str, Any]],
|
||||
progress_callback: Callable | None = None,
|
||||
cancellation_check: Callable[[], None] | None = None,
|
||||
provider: str | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
"""
|
||||
Generate summaries for all code blocks.
|
||||
@ -1587,7 +1589,7 @@ class CodeExtractionService:
|
||||
|
||||
try:
|
||||
results = await generate_code_summaries_batch(
|
||||
code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback
|
||||
code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback, provider=provider
|
||||
)
|
||||
|
||||
# Ensure all results are valid dicts
|
||||
@ -1667,6 +1669,7 @@ class CodeExtractionService:
|
||||
storage_data: dict[str, list[Any]],
|
||||
url_to_full_document: dict[str, str],
|
||||
progress_callback: Callable | None = None,
|
||||
provider: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Store code examples in the database.
|
||||
@ -1709,7 +1712,7 @@ class CodeExtractionService:
|
||||
batch_size=20,
|
||||
url_to_full_document=url_to_full_document,
|
||||
progress_callback=storage_progress_callback,
|
||||
provider=None, # Use configured provider
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
# Report completion of code extraction/storage phase
|
||||
|
||||
@ -475,12 +475,24 @@ class CrawlingService:
|
||||
)
|
||||
|
||||
try:
|
||||
# Extract provider from request or use credential service default
|
||||
provider = request.get("provider")
|
||||
if not provider:
|
||||
try:
|
||||
from ..credential_service import credential_service
|
||||
provider_config = await credential_service.get_active_provider("llm")
|
||||
provider = provider_config.get("provider", "openai")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai")
|
||||
provider = "openai"
|
||||
|
||||
code_examples_count = await self.doc_storage_ops.extract_and_store_code_examples(
|
||||
crawl_results,
|
||||
storage_results["url_to_full_document"],
|
||||
storage_results["source_id"],
|
||||
code_progress_callback,
|
||||
self._check_cancellation,
|
||||
provider,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
# Code extraction failed, continue crawl with warning
|
||||
|
||||
@ -351,6 +351,7 @@ class DocumentStorageOperations:
|
||||
source_id: str,
|
||||
progress_callback: Callable | None = None,
|
||||
cancellation_check: Callable[[], None] | None = None,
|
||||
provider: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Extract code examples from crawled documents and store them.
|
||||
@ -361,12 +362,13 @@ class DocumentStorageOperations:
|
||||
source_id: The unique source_id for all documents
|
||||
progress_callback: Optional callback for progress updates
|
||||
cancellation_check: Optional function to check for cancellation
|
||||
provider: Optional LLM provider to use for code summaries
|
||||
|
||||
Returns:
|
||||
Number of code examples stored
|
||||
"""
|
||||
result = await self.code_extraction_service.extract_and_store_code_examples(
|
||||
crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check
|
||||
crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check, provider
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@ -36,6 +36,44 @@ class CredentialItem:
|
||||
description: str | None = None
|
||||
|
||||
|
||||
def _detect_embedding_provider_from_model(embedding_model: str) -> str:
|
||||
"""
|
||||
Detect the appropriate embedding provider based on model name.
|
||||
|
||||
Args:
|
||||
embedding_model: The embedding model name
|
||||
|
||||
Returns:
|
||||
Provider name: 'google', 'openai', or 'openai' (default)
|
||||
"""
|
||||
if not embedding_model:
|
||||
return "openai" # Default
|
||||
|
||||
model_lower = embedding_model.lower()
|
||||
|
||||
# Google embedding models
|
||||
google_patterns = [
|
||||
"text-embedding-004",
|
||||
"text-embedding-005",
|
||||
"text-multilingual-embedding",
|
||||
"gemini-embedding",
|
||||
"multimodalembedding"
|
||||
]
|
||||
|
||||
if any(pattern in model_lower for pattern in google_patterns):
|
||||
return "google"
|
||||
|
||||
# OpenAI embedding models (and default for unknown)
|
||||
openai_patterns = [
|
||||
"text-embedding-ada-002",
|
||||
"text-embedding-3-small",
|
||||
"text-embedding-3-large"
|
||||
]
|
||||
|
||||
# Default to OpenAI for OpenAI models or unknown models
|
||||
return "openai"
|
||||
|
||||
|
||||
class CredentialService:
|
||||
"""Service for managing application credentials and configuration."""
|
||||
|
||||
@ -239,6 +277,14 @@ class CredentialService:
|
||||
self._rag_cache_timestamp = None
|
||||
logger.debug(f"Invalidated RAG settings cache due to update of {key}")
|
||||
|
||||
# Also invalidate provider service cache to ensure immediate effect
|
||||
try:
|
||||
from .llm_provider_service import clear_provider_cache
|
||||
clear_provider_cache()
|
||||
logger.debug("Also cleared LLM provider service cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear provider service cache: {e}")
|
||||
|
||||
# Also invalidate LLM provider service cache for provider config
|
||||
try:
|
||||
from . import llm_provider_service
|
||||
@ -281,6 +327,14 @@ class CredentialService:
|
||||
self._rag_cache_timestamp = None
|
||||
logger.debug(f"Invalidated RAG settings cache due to deletion of {key}")
|
||||
|
||||
# Also invalidate provider service cache to ensure immediate effect
|
||||
try:
|
||||
from .llm_provider_service import clear_provider_cache
|
||||
clear_provider_cache()
|
||||
logger.debug("Also cleared LLM provider service cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear provider service cache: {e}")
|
||||
|
||||
# Also invalidate LLM provider service cache for provider config
|
||||
try:
|
||||
from . import llm_provider_service
|
||||
@ -419,8 +473,33 @@ class CredentialService:
|
||||
# Get RAG strategy settings (where UI saves provider selection)
|
||||
rag_settings = await self.get_credentials_by_category("rag_strategy")
|
||||
|
||||
# Get the selected provider
|
||||
provider = rag_settings.get("LLM_PROVIDER", "openai")
|
||||
# Get the selected provider based on service type
|
||||
if service_type == "embedding":
|
||||
# Get the LLM provider setting to determine embedding provider
|
||||
llm_provider = rag_settings.get("LLM_PROVIDER", "openai")
|
||||
embedding_model = rag_settings.get("EMBEDDING_MODEL", "text-embedding-3-small")
|
||||
|
||||
# Determine embedding provider based on LLM provider
|
||||
if llm_provider == "google":
|
||||
provider = "google"
|
||||
elif llm_provider == "ollama":
|
||||
provider = "ollama"
|
||||
elif llm_provider == "openrouter":
|
||||
# OpenRouter supports both OpenAI and Google embedding models
|
||||
provider = _detect_embedding_provider_from_model(embedding_model)
|
||||
elif llm_provider in ["anthropic", "grok"]:
|
||||
# Anthropic and Grok support both OpenAI and Google embedding models
|
||||
provider = _detect_embedding_provider_from_model(embedding_model)
|
||||
else:
|
||||
# Default case (openai, or unknown providers)
|
||||
provider = "openai"
|
||||
|
||||
logger.debug(f"Determined embedding provider '{provider}' from LLM provider '{llm_provider}' and embedding model '{embedding_model}'")
|
||||
else:
|
||||
provider = rag_settings.get("LLM_PROVIDER", "openai")
|
||||
# Ensure provider is a valid string, not a boolean or other type
|
||||
if not isinstance(provider, str) or provider.lower() in ("true", "false", "none", "null"):
|
||||
provider = "openai"
|
||||
|
||||
# Get API key for this provider
|
||||
api_key = await self._get_provider_api_key(provider)
|
||||
@ -464,6 +543,9 @@ class CredentialService:
|
||||
key_mapping = {
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"google": "GOOGLE_API_KEY",
|
||||
"openrouter": "OPENROUTER_API_KEY",
|
||||
"anthropic": "ANTHROPIC_API_KEY",
|
||||
"grok": "GROK_API_KEY",
|
||||
"ollama": None, # No API key needed
|
||||
}
|
||||
|
||||
@ -478,6 +560,12 @@ class CredentialService:
|
||||
return rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434/v1")
|
||||
elif provider == "google":
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
elif provider == "openrouter":
|
||||
return "https://openrouter.ai/api/v1"
|
||||
elif provider == "anthropic":
|
||||
return "https://api.anthropic.com/v1"
|
||||
elif provider == "grok":
|
||||
return "https://api.x.ai/v1"
|
||||
return None # Use default for OpenAI
|
||||
|
||||
async def set_active_provider(self, provider: str, service_type: str = "llm") -> bool:
|
||||
@ -485,7 +573,7 @@ class CredentialService:
|
||||
try:
|
||||
# For now, we'll update the RAG strategy settings
|
||||
return await self.set_credential(
|
||||
"llm_provider",
|
||||
"LLM_PROVIDER",
|
||||
provider,
|
||||
category="rag_strategy",
|
||||
description=f"Active {service_type} provider",
|
||||
|
||||
@ -10,7 +10,13 @@ import os
|
||||
import openai
|
||||
|
||||
from ...config.logfire_config import search_logger
|
||||
from ..llm_provider_service import get_llm_client
|
||||
from ..credential_service import credential_service
|
||||
from ..llm_provider_service import (
|
||||
extract_message_text,
|
||||
get_llm_client,
|
||||
prepare_chat_completion_params,
|
||||
requires_max_completion_tokens,
|
||||
)
|
||||
from ..threading_service import get_threading_service
|
||||
|
||||
|
||||
@ -32,8 +38,6 @@ async def generate_contextual_embedding(
|
||||
"""
|
||||
# Model choice is a RAG setting, get from credential service
|
||||
try:
|
||||
from ...services.credential_service import credential_service
|
||||
|
||||
model_choice = await credential_service.get_credential("MODEL_CHOICE", "gpt-4.1-nano")
|
||||
except Exception as e:
|
||||
# Fallback to environment variable or default
|
||||
@ -65,20 +69,25 @@ Please give a short succinct context to situate this chunk within the overall do
|
||||
# Get model from provider configuration
|
||||
model = await _get_model_choice(provider)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
# Prepare parameters and convert max_tokens for GPT-5/reasoning models
|
||||
params = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that provides concise contextual information.",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=200,
|
||||
)
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 1200 if requires_max_completion_tokens(model) else 200, # Much more tokens for reasoning models (GPT-5 needs extra for reasoning process)
|
||||
}
|
||||
final_params = prepare_chat_completion_params(model, params)
|
||||
response = await client.chat.completions.create(**final_params)
|
||||
|
||||
context = response.choices[0].message.content.strip()
|
||||
choice = response.choices[0] if response.choices else None
|
||||
context, _, _ = extract_message_text(choice)
|
||||
context = context.strip()
|
||||
contextual_text = f"{context}\n---\n{chunk}"
|
||||
|
||||
return contextual_text, True
|
||||
@ -111,7 +120,7 @@ async def process_chunk_with_context(
|
||||
|
||||
|
||||
async def _get_model_choice(provider: str | None = None) -> str:
|
||||
"""Get model choice from credential service."""
|
||||
"""Get model choice from credential service with centralized defaults."""
|
||||
from ..credential_service import credential_service
|
||||
|
||||
# Get the active provider configuration
|
||||
@ -119,31 +128,36 @@ async def _get_model_choice(provider: str | None = None) -> str:
|
||||
model = provider_config.get("chat_model", "").strip() # Strip whitespace
|
||||
provider_name = provider_config.get("provider", "openai")
|
||||
|
||||
# Handle empty model case - fallback to provider-specific defaults or explicit config
|
||||
# Handle empty model case - use centralized defaults
|
||||
if not model:
|
||||
search_logger.warning(f"chat_model is empty for provider {provider_name}, using fallback logic")
|
||||
|
||||
search_logger.warning(f"chat_model is empty for provider {provider_name}, using centralized defaults")
|
||||
|
||||
# Special handling for Ollama to check specific credential
|
||||
if provider_name == "ollama":
|
||||
# Try to get OLLAMA_CHAT_MODEL specifically
|
||||
try:
|
||||
ollama_model = await credential_service.get_credential("OLLAMA_CHAT_MODEL")
|
||||
if ollama_model and ollama_model.strip():
|
||||
model = ollama_model.strip()
|
||||
search_logger.info(f"Using OLLAMA_CHAT_MODEL fallback: {model}")
|
||||
else:
|
||||
# Use a sensible Ollama default
|
||||
# Use default for Ollama
|
||||
model = "llama3.2:latest"
|
||||
search_logger.info(f"Using Ollama default model: {model}")
|
||||
search_logger.info(f"Using Ollama default: {model}")
|
||||
except Exception as e:
|
||||
search_logger.error(f"Error getting OLLAMA_CHAT_MODEL: {e}")
|
||||
model = "llama3.2:latest"
|
||||
search_logger.info(f"Using Ollama fallback model: {model}")
|
||||
elif provider_name == "google":
|
||||
model = "gemini-1.5-flash"
|
||||
search_logger.info(f"Using Ollama fallback: {model}")
|
||||
else:
|
||||
# OpenAI or other providers
|
||||
model = "gpt-4o-mini"
|
||||
|
||||
# Use provider-specific defaults
|
||||
provider_defaults = {
|
||||
"openai": "gpt-4o-mini",
|
||||
"openrouter": "anthropic/claude-3.5-sonnet",
|
||||
"google": "gemini-1.5-flash",
|
||||
"anthropic": "claude-3-5-haiku-20241022",
|
||||
"grok": "grok-3-mini"
|
||||
}
|
||||
model = provider_defaults.get(provider_name, "gpt-4o-mini")
|
||||
search_logger.debug(f"Using default model for provider {provider_name}: {model}")
|
||||
search_logger.debug(f"Using model from credential service: {model}")
|
||||
|
||||
return model
|
||||
@ -174,38 +188,48 @@ async def generate_contextual_embeddings_batch(
|
||||
model_choice = await _get_model_choice(provider)
|
||||
|
||||
# Build batch prompt for ALL chunks at once
|
||||
batch_prompt = (
|
||||
"Process the following chunks and provide contextual information for each:\\n\\n"
|
||||
)
|
||||
batch_prompt = "Process the following chunks and provide contextual information for each:\n\n"
|
||||
|
||||
for i, (doc, chunk) in enumerate(zip(full_documents, chunks, strict=False)):
|
||||
# Use only 2000 chars of document context to save tokens
|
||||
doc_preview = doc[:2000] if len(doc) > 2000 else doc
|
||||
batch_prompt += f"CHUNK {i + 1}:\\n"
|
||||
batch_prompt += f"<document_preview>\\n{doc_preview}\\n</document_preview>\\n"
|
||||
batch_prompt += f"<chunk>\\n{chunk[:500]}\\n</chunk>\\n\\n" # Limit chunk preview
|
||||
batch_prompt += f"CHUNK {i + 1}:\n"
|
||||
batch_prompt += f"<document_preview>\n{doc_preview}\n</document_preview>\n"
|
||||
batch_prompt += f"<chunk>\n{chunk[:500]}\n</chunk>\n\n" # Limit chunk preview
|
||||
|
||||
batch_prompt += "For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. Format your response as:\\nCHUNK 1: [context]\\nCHUNK 2: [context]\\netc."
|
||||
batch_prompt += (
|
||||
"For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. "
|
||||
"Format your response as:\nCHUNK 1: [context]\nCHUNK 2: [context]\netc."
|
||||
)
|
||||
|
||||
# Make single API call for ALL chunks
|
||||
response = await client.chat.completions.create(
|
||||
model=model_choice,
|
||||
messages=[
|
||||
# Prepare parameters and convert max_tokens for GPT-5/reasoning models
|
||||
batch_params = {
|
||||
"model": model_choice,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that generates contextual information for document chunks.",
|
||||
},
|
||||
{"role": "user", "content": batch_prompt},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=100 * len(chunks), # Limit response size
|
||||
)
|
||||
"temperature": 0,
|
||||
"max_tokens": (600 if requires_max_completion_tokens(model_choice) else 100) * len(chunks), # Much more tokens for reasoning models (GPT-5 needs extra reasoning space)
|
||||
}
|
||||
final_batch_params = prepare_chat_completion_params(model_choice, batch_params)
|
||||
response = await client.chat.completions.create(**final_batch_params)
|
||||
|
||||
# Parse response
|
||||
response_text = response.choices[0].message.content
|
||||
choice = response.choices[0] if response.choices else None
|
||||
response_text, _, _ = extract_message_text(choice)
|
||||
if not response_text:
|
||||
search_logger.error(
|
||||
"Empty response from LLM when generating contextual embeddings batch"
|
||||
)
|
||||
return [(chunk, False) for chunk in chunks]
|
||||
|
||||
# Extract contexts from response
|
||||
lines = response_text.strip().split("\\n")
|
||||
lines = response_text.strip().split("\n")
|
||||
chunk_contexts = {}
|
||||
|
||||
for line in lines:
|
||||
@ -245,4 +269,4 @@ async def generate_contextual_embeddings_batch(
|
||||
except Exception as e:
|
||||
search_logger.error(f"Error in contextual embedding batch: {e}")
|
||||
# Return non-contextual for all chunks
|
||||
return [(chunk, False) for chunk in chunks]
|
||||
return [(chunk, False) for chunk in chunks]
|
||||
|
||||
@ -13,7 +13,7 @@ import openai
|
||||
|
||||
from ...config.logfire_config import safe_span, search_logger
|
||||
from ..credential_service import credential_service
|
||||
from ..llm_provider_service import get_embedding_model, get_llm_client
|
||||
from ..llm_provider_service import get_embedding_model, get_llm_client, is_google_embedding_model, is_openai_embedding_model
|
||||
from ..threading_service import get_threading_service
|
||||
from .embedding_exceptions import (
|
||||
EmbeddingAPIError,
|
||||
@ -152,34 +152,56 @@ async def create_embeddings_batch(
|
||||
if not texts:
|
||||
return EmbeddingBatchResult()
|
||||
|
||||
result = EmbeddingBatchResult()
|
||||
|
||||
# Validate that all items in texts are strings
|
||||
validated_texts = []
|
||||
for i, text in enumerate(texts):
|
||||
if not isinstance(text, str):
|
||||
search_logger.error(
|
||||
f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True
|
||||
)
|
||||
# Try to convert to string
|
||||
try:
|
||||
validated_texts.append(str(text))
|
||||
except Exception as e:
|
||||
search_logger.error(
|
||||
f"Failed to convert text at index {i} to string: {e}", exc_info=True
|
||||
)
|
||||
validated_texts.append("") # Use empty string as fallback
|
||||
else:
|
||||
if isinstance(text, str):
|
||||
validated_texts.append(text)
|
||||
continue
|
||||
|
||||
search_logger.error(
|
||||
f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True
|
||||
)
|
||||
try:
|
||||
converted = str(text)
|
||||
validated_texts.append(converted)
|
||||
except Exception as conversion_error:
|
||||
search_logger.error(
|
||||
f"Failed to convert text at index {i} to string: {conversion_error}",
|
||||
exc_info=True,
|
||||
)
|
||||
result.add_failure(
|
||||
repr(text),
|
||||
EmbeddingAPIError("Invalid text type", original_error=conversion_error),
|
||||
batch_index=None,
|
||||
)
|
||||
|
||||
texts = validated_texts
|
||||
|
||||
result = EmbeddingBatchResult()
|
||||
threading_service = get_threading_service()
|
||||
|
||||
with safe_span(
|
||||
"create_embeddings_batch", text_count=len(texts), total_chars=sum(len(t) for t in texts)
|
||||
) as span:
|
||||
try:
|
||||
async with get_llm_client(provider=provider, use_embedding_provider=True) as client:
|
||||
# Intelligent embedding provider routing based on model type
|
||||
# Get the embedding model first to determine the correct provider
|
||||
embedding_model = await get_embedding_model(provider=provider)
|
||||
|
||||
# Route to correct provider based on model type
|
||||
if is_google_embedding_model(embedding_model):
|
||||
embedding_provider = "google"
|
||||
search_logger.info(f"Routing to Google for embedding model: {embedding_model}")
|
||||
elif is_openai_embedding_model(embedding_model) or "openai/" in embedding_model.lower():
|
||||
embedding_provider = "openai"
|
||||
search_logger.info(f"Routing to OpenAI for embedding model: {embedding_model}")
|
||||
else:
|
||||
# Keep original provider for ollama and other providers
|
||||
embedding_provider = provider
|
||||
search_logger.info(f"Using original provider '{provider}' for embedding model: {embedding_model}")
|
||||
|
||||
async with get_llm_client(provider=embedding_provider, use_embedding_provider=True) as client:
|
||||
# Load batch size and dimensions from settings
|
||||
try:
|
||||
rag_settings = await credential_service.get_credentials_by_category(
|
||||
@ -220,7 +242,8 @@ async def create_embeddings_batch(
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# Create embeddings for this batch
|
||||
embedding_model = await get_embedding_model(provider=provider)
|
||||
embedding_model = await get_embedding_model(provider=embedding_provider)
|
||||
|
||||
response = await client.embeddings.create(
|
||||
model=embedding_model,
|
||||
input=batch,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -2,7 +2,7 @@
|
||||
Provider Discovery Service
|
||||
|
||||
Discovers available models, checks provider health, and provides model specifications
|
||||
for OpenAI, Google Gemini, Ollama, and Anthropic providers.
|
||||
for OpenAI, Google Gemini, Ollama, Anthropic, and Grok providers.
|
||||
"""
|
||||
|
||||
import time
|
||||
@ -359,6 +359,36 @@ class ProviderDiscoveryService:
|
||||
|
||||
return models
|
||||
|
||||
async def discover_grok_models(self, api_key: str) -> list[ModelSpec]:
|
||||
"""Discover available Grok models."""
|
||||
cache_key = f"grok_models_{hash(api_key)}"
|
||||
cached = self._get_cached_result(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
models = []
|
||||
try:
|
||||
# Grok model specifications
|
||||
model_specs = [
|
||||
ModelSpec("grok-3-mini", "grok", 32768, True, True, False, None, 0.15, 0.60, "Fast and efficient Grok model"),
|
||||
ModelSpec("grok-3", "grok", 32768, True, True, False, None, 2.00, 10.00, "Standard Grok model"),
|
||||
ModelSpec("grok-4", "grok", 32768, True, True, False, None, 5.00, 25.00, "Advanced Grok model"),
|
||||
ModelSpec("grok-2-vision", "grok", 8192, True, True, True, None, 3.00, 15.00, "Grok model with vision capabilities"),
|
||||
ModelSpec("grok-2-latest", "grok", 8192, True, True, False, None, 2.00, 10.00, "Latest Grok 2 model"),
|
||||
]
|
||||
|
||||
# Test connectivity - Grok doesn't have a models list endpoint,
|
||||
# so we'll just return the known models if API key is provided
|
||||
if api_key:
|
||||
models = model_specs
|
||||
self._cache_result(cache_key, models)
|
||||
logger.info(f"Discovered {len(models)} Grok models")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error discovering Grok models: {e}")
|
||||
|
||||
return models
|
||||
|
||||
async def check_provider_health(self, provider: str, config: dict[str, Any]) -> ProviderStatus:
|
||||
"""Check health and connectivity status of a provider."""
|
||||
start_time = time.time()
|
||||
@ -456,6 +486,23 @@ class ProviderDiscoveryService:
|
||||
last_checked=time.time()
|
||||
)
|
||||
|
||||
elif provider == "grok":
|
||||
api_key = config.get("api_key")
|
||||
if not api_key:
|
||||
return ProviderStatus(provider, False, None, "API key not configured")
|
||||
|
||||
# Grok doesn't have a health check endpoint, so we'll assume it's available
|
||||
# if API key is provided. In a real implementation, you might want to make a
|
||||
# small test request to verify the key is valid.
|
||||
response_time = (time.time() - start_time) * 1000
|
||||
return ProviderStatus(
|
||||
provider="grok",
|
||||
is_available=True,
|
||||
response_time_ms=response_time,
|
||||
models_available=5, # Known model count
|
||||
last_checked=time.time()
|
||||
)
|
||||
|
||||
else:
|
||||
return ProviderStatus(provider, False, None, f"Unknown provider: {provider}")
|
||||
|
||||
@ -496,6 +543,11 @@ class ProviderDiscoveryService:
|
||||
if anthropic_key:
|
||||
providers["anthropic"] = await self.discover_anthropic_models(anthropic_key)
|
||||
|
||||
# Grok
|
||||
grok_key = await credential_service.get_credential("GROK_API_KEY")
|
||||
if grok_key:
|
||||
providers["grok"] = await self.discover_grok_models(grok_key)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all available models: {e}")
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from supabase import Client
|
||||
|
||||
from ..config.logfire_config import get_logger, search_logger
|
||||
from .client_manager import get_supabase_client
|
||||
from .llm_provider_service import get_llm_client
|
||||
from .llm_provider_service import extract_message_text, get_llm_client
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -72,20 +72,21 @@ The above content is from the documentation for '{source_id}'. Please provide a
|
||||
)
|
||||
|
||||
# Extract the generated summary with proper error handling
|
||||
if not response or not response.choices or len(response.choices) == 0:
|
||||
search_logger.error(f"Empty or invalid response from LLM for {source_id}")
|
||||
return default_summary
|
||||
|
||||
message_content = response.choices[0].message.content
|
||||
if message_content is None:
|
||||
search_logger.error(f"LLM returned None content for {source_id}")
|
||||
return default_summary
|
||||
|
||||
summary = message_content.strip()
|
||||
|
||||
# Ensure the summary is not too long
|
||||
if len(summary) > max_length:
|
||||
summary = summary[:max_length] + "..."
|
||||
if not response or not response.choices or len(response.choices) == 0:
|
||||
search_logger.error(f"Empty or invalid response from LLM for {source_id}")
|
||||
return default_summary
|
||||
|
||||
choice = response.choices[0]
|
||||
summary_text, _, _ = extract_message_text(choice)
|
||||
if not summary_text:
|
||||
search_logger.error(f"LLM returned None content for {source_id}")
|
||||
return default_summary
|
||||
|
||||
summary = summary_text.strip()
|
||||
|
||||
# Ensure the summary is not too long
|
||||
if len(summary) > max_length:
|
||||
summary = summary[:max_length] + "..."
|
||||
|
||||
return summary
|
||||
|
||||
@ -187,7 +188,9 @@ Generate only the title, nothing else."""
|
||||
],
|
||||
)
|
||||
|
||||
generated_title = response.choices[0].message.content.strip()
|
||||
choice = response.choices[0]
|
||||
generated_title, _, _ = extract_message_text(choice)
|
||||
generated_title = generated_title.strip()
|
||||
# Clean up the title
|
||||
generated_title = generated_title.strip("\"'")
|
||||
if len(generated_title) < 50: # Sanity check
|
||||
|
||||
@ -8,6 +8,7 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Callable
|
||||
from difflib import SequenceMatcher
|
||||
@ -17,25 +18,104 @@ from urllib.parse import urlparse
|
||||
from supabase import Client
|
||||
|
||||
from ...config.logfire_config import search_logger
|
||||
from ..credential_service import credential_service
|
||||
from ..embeddings.contextual_embedding_service import generate_contextual_embeddings_batch
|
||||
from ..embeddings.embedding_service import create_embeddings_batch
|
||||
from ..llm_provider_service import (
|
||||
extract_json_from_reasoning,
|
||||
extract_message_text,
|
||||
get_llm_client,
|
||||
prepare_chat_completion_params,
|
||||
synthesize_json_from_reasoning,
|
||||
)
|
||||
|
||||
|
||||
def _get_model_choice() -> str:
|
||||
"""Get MODEL_CHOICE with direct fallback."""
|
||||
def _extract_json_payload(raw_response: str, context_code: str = "", language: str = "") -> str:
|
||||
"""Return the best-effort JSON object from an LLM response."""
|
||||
|
||||
if not raw_response:
|
||||
return raw_response
|
||||
|
||||
cleaned = raw_response.strip()
|
||||
|
||||
# Check if this looks like reasoning text first
|
||||
if _is_reasoning_text_response(cleaned):
|
||||
# Try intelligent extraction from reasoning text with context
|
||||
extracted = extract_json_from_reasoning(cleaned, context_code, language)
|
||||
if extracted:
|
||||
return extracted
|
||||
# extract_json_from_reasoning may return nothing; synthesize a fallback JSON if so\
|
||||
fallback_json = synthesize_json_from_reasoning("", context_code, language)
|
||||
if fallback_json:
|
||||
return fallback_json
|
||||
# If all else fails, return a minimal valid JSON object to avoid downstream errors
|
||||
return '{"example_name": "Code Example", "summary": "Code example extracted from context."}'
|
||||
|
||||
|
||||
if cleaned.startswith("```"):
|
||||
lines = cleaned.splitlines()
|
||||
# Drop opening fence
|
||||
lines = lines[1:]
|
||||
# Drop closing fence if present
|
||||
if lines and lines[-1].strip().startswith("```"):
|
||||
lines = lines[:-1]
|
||||
cleaned = "\n".join(lines).strip()
|
||||
|
||||
# Trim any leading/trailing text outside the outermost JSON braces
|
||||
start = cleaned.find("{")
|
||||
end = cleaned.rfind("}")
|
||||
if start != -1 and end != -1 and end >= start:
|
||||
cleaned = cleaned[start : end + 1]
|
||||
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
REASONING_STARTERS = [
|
||||
"okay, let's see", "okay, let me", "let me think", "first, i need to", "looking at this",
|
||||
"i need to", "analyzing", "let me work through", "thinking about", "let me see"
|
||||
]
|
||||
|
||||
def _is_reasoning_text_response(text: str) -> bool:
|
||||
"""Detect if response is reasoning text rather than direct JSON."""
|
||||
if not text or len(text) < 20:
|
||||
return False
|
||||
|
||||
text_lower = text.lower().strip()
|
||||
|
||||
# Check if it's clearly not JSON (starts with reasoning text)
|
||||
starts_with_reasoning = any(text_lower.startswith(starter) for starter in REASONING_STARTERS)
|
||||
|
||||
# Check if it lacks immediate JSON structure
|
||||
lacks_immediate_json = not text_lower.lstrip().startswith('{')
|
||||
|
||||
return starts_with_reasoning or (lacks_immediate_json and any(pattern in text_lower for pattern in REASONING_STARTERS))
|
||||
async def _get_model_choice() -> str:
|
||||
"""Get MODEL_CHOICE with provider-aware defaults from centralized service."""
|
||||
try:
|
||||
# Direct cache/env fallback
|
||||
from ..credential_service import credential_service
|
||||
# Get the active provider configuration
|
||||
provider_config = await credential_service.get_active_provider("llm")
|
||||
active_provider = provider_config.get("provider", "openai")
|
||||
model = provider_config.get("chat_model")
|
||||
|
||||
if credential_service._cache_initialized and "MODEL_CHOICE" in credential_service._cache:
|
||||
model = credential_service._cache["MODEL_CHOICE"]
|
||||
else:
|
||||
model = os.getenv("MODEL_CHOICE", "gpt-4.1-nano")
|
||||
search_logger.debug(f"Using model choice: {model}")
|
||||
# If no custom model is set, use provider-specific defaults
|
||||
if not model or model.strip() == "":
|
||||
# Provider-specific defaults
|
||||
provider_defaults = {
|
||||
"openai": "gpt-4o-mini",
|
||||
"openrouter": "anthropic/claude-3.5-sonnet",
|
||||
"google": "gemini-1.5-flash",
|
||||
"ollama": "llama3.2:latest",
|
||||
"anthropic": "claude-3-5-haiku-20241022",
|
||||
"grok": "grok-3-mini"
|
||||
}
|
||||
model = provider_defaults.get(active_provider, "gpt-4o-mini")
|
||||
search_logger.debug(f"Using default model for provider {active_provider}: {model}")
|
||||
|
||||
search_logger.debug(f"Using model for provider {active_provider}: {model}")
|
||||
return model
|
||||
except Exception as e:
|
||||
search_logger.warning(f"Error getting model choice: {e}, using default")
|
||||
return "gpt-4.1-nano"
|
||||
return "gpt-4o-mini"
|
||||
|
||||
|
||||
def _get_max_workers() -> int:
|
||||
@ -155,6 +235,7 @@ def _select_best_code_variant(similar_blocks: list[dict[str, Any]]) -> dict[str,
|
||||
return best_block
|
||||
|
||||
|
||||
|
||||
def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract code blocks from markdown content along with context.
|
||||
@ -168,8 +249,6 @@ def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[d
|
||||
"""
|
||||
# Load all code extraction settings with direct fallback
|
||||
try:
|
||||
from ...services.credential_service import credential_service
|
||||
|
||||
def _get_setting_fallback(key: str, default: str) -> str:
|
||||
if credential_service._cache_initialized and key in credential_service._cache:
|
||||
return credential_service._cache[key]
|
||||
@ -507,7 +586,7 @@ def generate_code_example_summary(
|
||||
A dictionary with 'summary' and 'example_name'
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
|
||||
# Run the async version in the current thread
|
||||
return asyncio.run(_generate_code_example_summary_async(code, context_before, context_after, language, provider))
|
||||
|
||||
@ -518,13 +597,22 @@ async def _generate_code_example_summary_async(
|
||||
"""
|
||||
Async version of generate_code_example_summary using unified LLM provider service.
|
||||
"""
|
||||
from ..llm_provider_service import get_llm_client
|
||||
|
||||
# Get model choice from credential service (RAG setting)
|
||||
model_choice = _get_model_choice()
|
||||
|
||||
# Create the prompt
|
||||
prompt = f"""<context_before>
|
||||
# Get model choice from credential service (RAG setting)
|
||||
model_choice = await _get_model_choice()
|
||||
|
||||
# If provider is not specified, get it from credential service
|
||||
if provider is None:
|
||||
try:
|
||||
provider_config = await credential_service.get_active_provider("llm")
|
||||
provider = provider_config.get("provider", "openai")
|
||||
search_logger.debug(f"Auto-detected provider from credential service: {provider}")
|
||||
except Exception as e:
|
||||
search_logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai")
|
||||
provider = "openai"
|
||||
|
||||
# Create the prompt variants: base prompt, guarded prompt (JSON reminder), and strict prompt for retries
|
||||
base_prompt = f"""<context_before>
|
||||
{context_before[-500:] if len(context_before) > 500 else context_before}
|
||||
</context_before>
|
||||
|
||||
@ -548,6 +636,16 @@ Format your response as JSON:
|
||||
"summary": "2-3 sentence description of what the code demonstrates"
|
||||
}}
|
||||
"""
|
||||
guard_prompt = (
|
||||
base_prompt
|
||||
+ "\n\nImportant: Respond with a valid JSON object that exactly matches the keys "
|
||||
'{"example_name": string, "summary": string}. Do not include commentary, '
|
||||
"markdown fences, or reasoning notes."
|
||||
)
|
||||
strict_prompt = (
|
||||
guard_prompt
|
||||
+ "\n\nSecond attempt enforcement: Return JSON only with the exact schema. No additional text or reasoning content."
|
||||
)
|
||||
|
||||
try:
|
||||
# Use unified LLM provider service
|
||||
@ -555,25 +653,261 @@ Format your response as JSON:
|
||||
search_logger.info(
|
||||
f"Generating summary for {hash(code) & 0xffffff:06x} using model: {model_choice}"
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model_choice,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
max_tokens=500,
|
||||
temperature=0.3,
|
||||
|
||||
provider_lower = provider.lower()
|
||||
is_grok_model = (provider_lower == "grok") or ("grok" in model_choice.lower())
|
||||
|
||||
supports_response_format_base = (
|
||||
provider_lower in {"openai", "google", "anthropic"}
|
||||
or (provider_lower == "openrouter" and model_choice.startswith("openai/"))
|
||||
)
|
||||
|
||||
response_content = response.choices[0].message.content.strip()
|
||||
last_response_obj = None
|
||||
last_elapsed_time = None
|
||||
last_response_content = ""
|
||||
last_json_error: json.JSONDecodeError | None = None
|
||||
|
||||
for enforce_json, current_prompt in ((False, guard_prompt), (True, strict_prompt)):
|
||||
request_params = {
|
||||
"model": model_choice,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.",
|
||||
},
|
||||
{"role": "user", "content": current_prompt},
|
||||
],
|
||||
"max_tokens": 2000,
|
||||
"temperature": 0.3,
|
||||
}
|
||||
|
||||
should_use_response_format = False
|
||||
if enforce_json:
|
||||
if not is_grok_model and (supports_response_format_base or provider_lower == "openrouter"):
|
||||
should_use_response_format = True
|
||||
else:
|
||||
if supports_response_format_base:
|
||||
should_use_response_format = True
|
||||
|
||||
if should_use_response_format:
|
||||
request_params["response_format"] = {"type": "json_object"}
|
||||
|
||||
if is_grok_model:
|
||||
unsupported_params = ["presence_penalty", "frequency_penalty", "stop", "reasoning_effort"]
|
||||
for param in unsupported_params:
|
||||
if param in request_params:
|
||||
removed_value = request_params.pop(param)
|
||||
search_logger.warning(f"Removed unsupported Grok parameter '{param}': {removed_value}")
|
||||
|
||||
supported_params = ["model", "messages", "max_tokens", "temperature", "response_format", "stream", "tools", "tool_choice"]
|
||||
for param in list(request_params.keys()):
|
||||
if param not in supported_params:
|
||||
search_logger.warning(f"Parameter '{param}' may not be supported by Grok reasoning models")
|
||||
|
||||
start_time = time.time()
|
||||
max_retries = 3 if is_grok_model else 1
|
||||
retry_delay = 1.0
|
||||
response_content_local = ""
|
||||
reasoning_text_local = ""
|
||||
json_error_occurred = False
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
if is_grok_model and attempt > 0:
|
||||
search_logger.info(f"Grok retry attempt {attempt + 1}/{max_retries} after {retry_delay:.1f}s delay")
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
final_params = prepare_chat_completion_params(model_choice, request_params)
|
||||
response = await client.chat.completions.create(**final_params)
|
||||
last_response_obj = response
|
||||
|
||||
choice = response.choices[0] if response.choices else None
|
||||
message = choice.message if choice and hasattr(choice, "message") else None
|
||||
response_content_local = ""
|
||||
reasoning_text_local = ""
|
||||
|
||||
if choice:
|
||||
response_content_local, reasoning_text_local, _ = extract_message_text(choice)
|
||||
|
||||
# Enhanced logging for response analysis
|
||||
if message and reasoning_text_local:
|
||||
content_preview = response_content_local[:100] if response_content_local else "None"
|
||||
reasoning_preview = reasoning_text_local[:100] if reasoning_text_local else "None"
|
||||
search_logger.debug(
|
||||
f"Response has reasoning content - content: '{content_preview}', reasoning: '{reasoning_preview}'"
|
||||
)
|
||||
|
||||
if response_content_local:
|
||||
last_response_content = response_content_local.strip()
|
||||
|
||||
# Pre-validate response before processing
|
||||
if len(last_response_content) < 20 or (len(last_response_content) < 50 and not last_response_content.strip().startswith('{')):
|
||||
# Very minimal response - likely "Okay\nOkay" type
|
||||
search_logger.debug(f"Minimal response detected: {repr(last_response_content)}")
|
||||
# Generate fallback directly from context
|
||||
fallback_json = synthesize_json_from_reasoning("", code, language)
|
||||
if fallback_json:
|
||||
try:
|
||||
result = json.loads(fallback_json)
|
||||
final_result = {
|
||||
"example_name": result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
|
||||
"summary": result.get("summary", "Code example for demonstration purposes."),
|
||||
}
|
||||
search_logger.info(f"Generated fallback summary from context - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}")
|
||||
return final_result
|
||||
except json.JSONDecodeError:
|
||||
pass # Continue to normal error handling
|
||||
else:
|
||||
# Even synthesis failed - provide hardcoded fallback for minimal responses
|
||||
final_result = {
|
||||
"example_name": f"Code Example{f' ({language})' if language else ''}",
|
||||
"summary": "Code example extracted from development context.",
|
||||
}
|
||||
search_logger.info(f"Used hardcoded fallback for minimal response - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}")
|
||||
return final_result
|
||||
|
||||
payload = _extract_json_payload(last_response_content, code, language)
|
||||
if payload != last_response_content:
|
||||
search_logger.debug(
|
||||
f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..."
|
||||
)
|
||||
|
||||
try:
|
||||
result = json.loads(payload)
|
||||
|
||||
if not result.get("example_name") or not result.get("summary"):
|
||||
search_logger.warning(f"Incomplete response from LLM: {result}")
|
||||
|
||||
final_result = {
|
||||
"example_name": result.get(
|
||||
"example_name", f"Code Example{f' ({language})' if language else ''}"
|
||||
),
|
||||
"summary": result.get("summary", "Code example for demonstration purposes."),
|
||||
}
|
||||
|
||||
search_logger.info(
|
||||
f"Generated code example summary - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}"
|
||||
)
|
||||
return final_result
|
||||
|
||||
except json.JSONDecodeError as json_error:
|
||||
last_json_error = json_error
|
||||
json_error_occurred = True
|
||||
snippet = last_response_content[:200]
|
||||
if not enforce_json:
|
||||
# Check if this was reasoning text that couldn't be parsed
|
||||
if _is_reasoning_text_response(last_response_content):
|
||||
search_logger.debug(
|
||||
f"Reasoning text detected but no JSON extracted. Response snippet: {repr(snippet)}"
|
||||
)
|
||||
else:
|
||||
search_logger.warning(
|
||||
f"Failed to parse JSON response from LLM (non-strict attempt). Error: {json_error}. Response snippet: {repr(snippet)}"
|
||||
)
|
||||
break
|
||||
else:
|
||||
search_logger.error(
|
||||
f"Strict JSON enforcement still failed to produce valid JSON: {json_error}. Response snippet: {repr(snippet)}"
|
||||
)
|
||||
break
|
||||
|
||||
elif is_grok_model and attempt < max_retries - 1:
|
||||
search_logger.warning(f"Grok empty response on attempt {attempt + 1}, retrying...")
|
||||
retry_delay *= 2
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if is_grok_model and attempt < max_retries - 1:
|
||||
search_logger.error(f"Grok request failed on attempt {attempt + 1}: {e}, retrying...")
|
||||
retry_delay *= 2
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
if is_grok_model:
|
||||
elapsed_time = time.time() - start_time
|
||||
last_elapsed_time = elapsed_time
|
||||
search_logger.debug(f"Grok total response time: {elapsed_time:.2f}s")
|
||||
|
||||
if json_error_occurred:
|
||||
if not enforce_json:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
if response_content_local:
|
||||
# We would have returned already on success; if we reach here, parsing failed but we are not retrying
|
||||
continue
|
||||
|
||||
response_content = last_response_content
|
||||
response = last_response_obj
|
||||
elapsed_time = last_elapsed_time if last_elapsed_time is not None else 0.0
|
||||
|
||||
if last_json_error is not None and response_content:
|
||||
search_logger.error(
|
||||
f"LLM response after strict enforcement was still not valid JSON: {last_json_error}. Clearing response to trigger error handling."
|
||||
)
|
||||
response_content = ""
|
||||
|
||||
if not response_content:
|
||||
search_logger.error(f"Empty response from LLM for model: {model_choice} (provider: {provider})")
|
||||
if is_grok_model:
|
||||
search_logger.error("Grok empty response debugging:")
|
||||
search_logger.error(f" - Request took: {elapsed_time:.2f}s")
|
||||
search_logger.error(f" - Response status: {getattr(response, 'status_code', 'N/A')}")
|
||||
search_logger.error(f" - Response headers: {getattr(response, 'headers', 'N/A')}")
|
||||
search_logger.error(f" - Full response: {response}")
|
||||
search_logger.error(f" - Response choices length: {len(response.choices) if response.choices else 0}")
|
||||
if response.choices:
|
||||
search_logger.error(f" - First choice: {response.choices[0]}")
|
||||
search_logger.error(f" - Message content: '{response.choices[0].message.content}'")
|
||||
search_logger.error(f" - Message role: {response.choices[0].message.role}")
|
||||
search_logger.error("Check: 1) API key validity, 2) rate limits, 3) model availability")
|
||||
|
||||
# Implement fallback for Grok failures
|
||||
search_logger.warning("Attempting fallback to OpenAI due to Grok failure...")
|
||||
try:
|
||||
# Use OpenAI as fallback with similar parameters
|
||||
fallback_params = {
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": request_params["messages"],
|
||||
"temperature": request_params.get("temperature", 0.1),
|
||||
"max_tokens": request_params.get("max_tokens", 500),
|
||||
}
|
||||
|
||||
async with get_llm_client(provider="openai") as fallback_client:
|
||||
fallback_response = await fallback_client.chat.completions.create(**fallback_params)
|
||||
fallback_content = fallback_response.choices[0].message.content
|
||||
if fallback_content and fallback_content.strip():
|
||||
search_logger.info("gpt-4o-mini fallback succeeded")
|
||||
response_content = fallback_content.strip()
|
||||
else:
|
||||
search_logger.error("gpt-4o-mini fallback also returned empty response")
|
||||
raise ValueError(f"Both {model_choice} and gpt-4o-mini fallback failed")
|
||||
|
||||
except Exception as fallback_error:
|
||||
search_logger.error(f"gpt-4o-mini fallback failed: {fallback_error}")
|
||||
raise ValueError(f"{model_choice} failed and fallback to gpt-4o-mini also failed: {fallback_error}") from fallback_error
|
||||
else:
|
||||
search_logger.debug(f"Full response object: {response}")
|
||||
raise ValueError("Empty response from LLM")
|
||||
|
||||
if not response_content:
|
||||
# This should not happen after fallback logic, but safety check
|
||||
raise ValueError("No valid response content after all attempts")
|
||||
|
||||
response_content = response_content.strip()
|
||||
search_logger.debug(f"LLM API response: {repr(response_content[:200])}...")
|
||||
|
||||
result = json.loads(response_content)
|
||||
payload = _extract_json_payload(response_content, code, language)
|
||||
if payload != response_content:
|
||||
search_logger.debug(
|
||||
f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..."
|
||||
)
|
||||
|
||||
result = json.loads(payload)
|
||||
|
||||
# Validate the response has the required fields
|
||||
if not result.get("example_name") or not result.get("summary"):
|
||||
@ -595,12 +929,38 @@ Format your response as JSON:
|
||||
search_logger.error(
|
||||
f"Failed to parse JSON response from LLM: {e}, Response: {repr(response_content) if 'response_content' in locals() else 'No response'}"
|
||||
)
|
||||
# Try to generate context-aware fallback
|
||||
try:
|
||||
fallback_json = synthesize_json_from_reasoning("", code, language)
|
||||
if fallback_json:
|
||||
fallback_result = json.loads(fallback_json)
|
||||
search_logger.info(f"Generated context-aware fallback summary")
|
||||
return {
|
||||
"example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
|
||||
"summary": fallback_result.get("summary", "Code example for demonstration purposes."),
|
||||
}
|
||||
except Exception:
|
||||
pass # Fall through to generic fallback
|
||||
|
||||
return {
|
||||
"example_name": f"Code Example{f' ({language})' if language else ''}",
|
||||
"summary": "Code example for demonstration purposes.",
|
||||
}
|
||||
except Exception as e:
|
||||
search_logger.error(f"Error generating code summary using unified LLM provider: {e}")
|
||||
# Try to generate context-aware fallback
|
||||
try:
|
||||
fallback_json = synthesize_json_from_reasoning("", code, language)
|
||||
if fallback_json:
|
||||
fallback_result = json.loads(fallback_json)
|
||||
search_logger.info(f"Generated context-aware fallback summary after error")
|
||||
return {
|
||||
"example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
|
||||
"summary": fallback_result.get("summary", "Code example for demonstration purposes."),
|
||||
}
|
||||
except Exception:
|
||||
pass # Fall through to generic fallback
|
||||
|
||||
return {
|
||||
"example_name": f"Code Example{f' ({language})' if language else ''}",
|
||||
"summary": "Code example for demonstration purposes.",
|
||||
@ -608,7 +968,7 @@ Format your response as JSON:
|
||||
|
||||
|
||||
async def generate_code_summaries_batch(
|
||||
code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None
|
||||
code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None, provider: str = None
|
||||
) -> list[dict[str, str]]:
|
||||
"""
|
||||
Generate summaries for multiple code blocks with rate limiting and proper worker management.
|
||||
@ -617,6 +977,7 @@ async def generate_code_summaries_batch(
|
||||
code_blocks: List of code block dictionaries
|
||||
max_workers: Maximum number of concurrent API requests
|
||||
progress_callback: Optional callback for progress updates (async function)
|
||||
provider: LLM provider to use for generation (e.g., 'grok', 'openai', 'anthropic')
|
||||
|
||||
Returns:
|
||||
List of summary dictionaries
|
||||
@ -627,8 +988,6 @@ async def generate_code_summaries_batch(
|
||||
# Get max_workers from settings if not provided
|
||||
if max_workers is None:
|
||||
try:
|
||||
from ...services.credential_service import credential_service
|
||||
|
||||
if (
|
||||
credential_service._cache_initialized
|
||||
and "CODE_SUMMARY_MAX_WORKERS" in credential_service._cache
|
||||
@ -663,6 +1022,7 @@ async def generate_code_summaries_batch(
|
||||
block["context_before"],
|
||||
block["context_after"],
|
||||
block.get("language", ""),
|
||||
provider,
|
||||
)
|
||||
|
||||
# Update progress
|
||||
@ -757,29 +1117,17 @@ async def add_code_examples_to_supabase(
|
||||
except Exception as e:
|
||||
search_logger.error(f"Error deleting existing code examples for {url}: {e}")
|
||||
|
||||
# Check if contextual embeddings are enabled
|
||||
# Check if contextual embeddings are enabled (use proper async method like document storage)
|
||||
try:
|
||||
from ..credential_service import credential_service
|
||||
|
||||
use_contextual_embeddings = credential_service._cache.get("USE_CONTEXTUAL_EMBEDDINGS")
|
||||
if isinstance(use_contextual_embeddings, str):
|
||||
use_contextual_embeddings = use_contextual_embeddings.lower() == "true"
|
||||
elif isinstance(use_contextual_embeddings, dict) and use_contextual_embeddings.get(
|
||||
"is_encrypted"
|
||||
):
|
||||
# Handle encrypted value
|
||||
encrypted_value = use_contextual_embeddings.get("encrypted_value")
|
||||
if encrypted_value:
|
||||
try:
|
||||
decrypted = credential_service._decrypt_value(encrypted_value)
|
||||
use_contextual_embeddings = decrypted.lower() == "true"
|
||||
except:
|
||||
use_contextual_embeddings = False
|
||||
else:
|
||||
use_contextual_embeddings = False
|
||||
raw_value = await credential_service.get_credential(
|
||||
"USE_CONTEXTUAL_EMBEDDINGS", "false", decrypt=True
|
||||
)
|
||||
if isinstance(raw_value, str):
|
||||
use_contextual_embeddings = raw_value.lower() == "true"
|
||||
else:
|
||||
use_contextual_embeddings = bool(use_contextual_embeddings)
|
||||
except:
|
||||
use_contextual_embeddings = bool(raw_value)
|
||||
except Exception as e:
|
||||
search_logger.error(f"DEBUG: Error reading contextual embeddings: {e}")
|
||||
# Fallback to environment variable
|
||||
use_contextual_embeddings = (
|
||||
os.getenv("USE_CONTEXTUAL_EMBEDDINGS", "false").lower() == "true"
|
||||
@ -848,14 +1196,13 @@ async def add_code_examples_to_supabase(
|
||||
# Use only successful embeddings
|
||||
valid_embeddings = result.embeddings
|
||||
successful_texts = result.texts_processed
|
||||
|
||||
|
||||
# Get model information for tracking
|
||||
from ..llm_provider_service import get_embedding_model
|
||||
from ..credential_service import credential_service
|
||||
|
||||
|
||||
# Get embedding model name
|
||||
embedding_model_name = await get_embedding_model(provider=provider)
|
||||
|
||||
|
||||
# Get LLM chat model (used for code summaries and contextual embeddings if enabled)
|
||||
llm_chat_model = None
|
||||
try:
|
||||
@ -868,7 +1215,7 @@ async def add_code_examples_to_supabase(
|
||||
llm_chat_model = await credential_service.get_credential("MODEL_CHOICE", "gpt-4o-mini")
|
||||
else:
|
||||
# For code summaries, we use MODEL_CHOICE
|
||||
llm_chat_model = _get_model_choice()
|
||||
llm_chat_model = await _get_model_choice()
|
||||
except Exception as e:
|
||||
search_logger.warning(f"Failed to get LLM chat model: {e}")
|
||||
llm_chat_model = "gpt-4o-mini" # Default fallback
|
||||
@ -888,7 +1235,7 @@ async def add_code_examples_to_supabase(
|
||||
positions_by_text[text].append(original_indices[k])
|
||||
|
||||
# Map successful texts back to their original indices
|
||||
for embedding, text in zip(valid_embeddings, successful_texts, strict=False):
|
||||
for embedding, text in zip(valid_embeddings, successful_texts, strict=True):
|
||||
# Get the next available index for this text (handles duplicates)
|
||||
if positions_by_text[text]:
|
||||
orig_idx = positions_by_text[text].popleft() # Original j index in [i, batch_end)
|
||||
@ -908,7 +1255,7 @@ async def add_code_examples_to_supabase(
|
||||
# Determine the correct embedding column based on dimension
|
||||
embedding_dim = len(embedding) if isinstance(embedding, list) else len(embedding.tolist())
|
||||
embedding_column = None
|
||||
|
||||
|
||||
if embedding_dim == 768:
|
||||
embedding_column = "embedding_768"
|
||||
elif embedding_dim == 1024:
|
||||
@ -918,10 +1265,12 @@ async def add_code_examples_to_supabase(
|
||||
elif embedding_dim == 3072:
|
||||
embedding_column = "embedding_3072"
|
||||
else:
|
||||
# Default to closest supported dimension
|
||||
search_logger.warning(f"Unsupported embedding dimension {embedding_dim}, using embedding_1536")
|
||||
embedding_column = "embedding_1536"
|
||||
|
||||
# Skip unsupported dimensions to avoid corrupting the schema
|
||||
search_logger.error(
|
||||
f"Unsupported embedding dimension {embedding_dim}; skipping record to prevent column mismatch"
|
||||
)
|
||||
continue
|
||||
|
||||
batch_data.append({
|
||||
"url": urls[idx],
|
||||
"chunk_number": chunk_numbers[idx],
|
||||
@ -954,9 +1303,7 @@ async def add_code_examples_to_supabase(
|
||||
f"Error inserting batch into Supabase (attempt {retry + 1}/{max_retries}): {e}"
|
||||
)
|
||||
search_logger.info(f"Retrying in {retry_delay} seconds...")
|
||||
import time
|
||||
|
||||
time.sleep(retry_delay)
|
||||
await asyncio.sleep(retry_delay)
|
||||
retry_delay *= 2 # Exponential backoff
|
||||
else:
|
||||
# Final attempt failed
|
||||
|
||||
@ -33,6 +33,12 @@ class AsyncContextManager:
|
||||
class TestAsyncLLMProviderService:
|
||||
"""Test suite for async LLM provider service functions"""
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_client():
|
||||
client = MagicMock()
|
||||
client.aclose = AsyncMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache(self):
|
||||
"""Clear cache before each test"""
|
||||
@ -98,7 +104,7 @@ class TestAsyncLLMProviderService:
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_client = self._make_mock_client()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
async with get_llm_client() as client:
|
||||
@ -121,7 +127,7 @@ class TestAsyncLLMProviderService:
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_client = self._make_mock_client()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
async with get_llm_client() as client:
|
||||
@ -143,7 +149,7 @@ class TestAsyncLLMProviderService:
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_client = self._make_mock_client()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
async with get_llm_client() as client:
|
||||
@ -166,7 +172,7 @@ class TestAsyncLLMProviderService:
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_client = self._make_mock_client()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
async with get_llm_client(provider="openai") as client:
|
||||
@ -194,7 +200,7 @@ class TestAsyncLLMProviderService:
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_client = self._make_mock_client()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
async with get_llm_client(use_embedding_provider=True) as client:
|
||||
@ -225,7 +231,7 @@ class TestAsyncLLMProviderService:
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_client = self._make_mock_client()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
# Should fallback to Ollama instead of raising an error
|
||||
@ -426,7 +432,7 @@ class TestAsyncLLMProviderService:
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_client = self._make_mock_client()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
# First call should hit the credential service
|
||||
@ -464,7 +470,7 @@ class TestAsyncLLMProviderService:
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_client = self._make_mock_client()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
client_ref = None
|
||||
@ -474,6 +480,7 @@ class TestAsyncLLMProviderService:
|
||||
|
||||
# After context manager exits, should still have reference to client
|
||||
assert client_ref == mock_client
|
||||
mock_client.aclose.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_providers_in_sequence(self, mock_credential_service):
|
||||
@ -494,7 +501,7 @@ class TestAsyncLLMProviderService:
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_client = self._make_mock_client()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
for config in configs:
|
||||
|
||||
@ -104,13 +104,15 @@ class TestCodeExtractionSourceId:
|
||||
)
|
||||
|
||||
# Verify the correct source_id was passed (now with cancellation_check parameter)
|
||||
mock_extract.assert_called_once_with(
|
||||
crawl_results,
|
||||
url_to_full_document,
|
||||
source_id, # This should be the third argument
|
||||
None,
|
||||
None # cancellation_check parameter
|
||||
)
|
||||
mock_extract.assert_called_once()
|
||||
args, kwargs = mock_extract.call_args
|
||||
assert args[0] == crawl_results
|
||||
assert args[1] == url_to_full_document
|
||||
assert args[2] == source_id
|
||||
assert args[3] is None
|
||||
assert args[4] is None
|
||||
if len(args) > 5:
|
||||
assert args[5] is None
|
||||
assert result == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -174,4 +176,4 @@ class TestCodeExtractionSourceId:
|
||||
import inspect
|
||||
source = inspect.getsource(module)
|
||||
assert "from urllib.parse import urlparse" not in source, \
|
||||
"Should not import urlparse since we don't extract domain from URL anymore"
|
||||
"Should not import urlparse since we don't extract domain from URL anymore"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user