diff --git a/src/api/providers/fetchers/__tests__/modelCache.spec.ts b/src/api/providers/fetchers/__tests__/modelCache.spec.ts
index 3c73b2a2725..1b95da39aac 100644
--- a/src/api/providers/fetchers/__tests__/modelCache.spec.ts
+++ b/src/api/providers/fetchers/__tests__/modelCache.spec.ts
@@ -59,7 +59,7 @@ vi.mock("../../../core/config/ContextProxy", () => ({
 import type { Mock } from "vitest"
 import * as fsSync from "fs"
 import NodeCache from "node-cache"
-import { getModels, getModelsFromCache } from "../modelCache"
+import { getModels, getModelsFromCache, getCacheKey } from "../modelCache"
 import { getLiteLLMModels } from "../litellm"
 import { getOpenRouterModels } from "../openrouter"
 import { getRequestyModels } from "../requesty"
@@ -474,3 +474,227 @@ describe("empty cache protection", () => {
 		})
 	})
 })
+
+describe("multi-instance provider caching", () => {
+	describe("getCacheKey", () => {
+		it("returns provider name for non-multi-instance providers", () => {
+			expect(getCacheKey({ provider: "openrouter" })).toBe("openrouter")
+			expect(getCacheKey({ provider: "huggingface" })).toBe("huggingface")
+			expect(getCacheKey({ provider: "vercel-ai-gateway" })).toBe("vercel-ai-gateway")
+		})
+
+		it("returns provider name for multi-instance providers without baseUrl", () => {
+			expect(getCacheKey({ provider: "litellm", apiKey: "test-key", baseUrl: "" })).toBe("litellm")
+			expect(getCacheKey({ provider: "ollama" })).toBe("ollama")
+			expect(getCacheKey({ provider: "lmstudio" })).toBe("lmstudio")
+		})
+
+		it("returns unique key for multi-instance providers with baseUrl", () => {
+			const key1 = getCacheKey({
+				provider: "litellm",
+				apiKey: "test-key",
+				baseUrl: "http://localhost:4000",
+			})
+			const key2 = getCacheKey({
+				provider: "litellm",
+				apiKey: "test-key",
+				baseUrl: "http://localhost:5000",
+			})
+
+			// Keys should be different for different URLs
+			expect(key1).not.toBe(key2)
+			// Keys should start with provider name
+			expect(key1).toMatch(/^litellm_/)
+			expect(key2).toMatch(/^litellm_/)
+		})
+
+		it("returns same key for same provider and baseUrl", () => {
+			const key1 = getCacheKey({
+				provider: "litellm",
+				apiKey: "key1",
+				baseUrl: "http://localhost:4000",
+			})
+			const key2 = getCacheKey({
+				provider: "litellm",
+				apiKey: "different-key", // Different API key should not affect cache key
+				baseUrl: "http://localhost:4000",
+			})
+
+			expect(key1).toBe(key2)
+		})
+
+		it("generates unique keys for all multi-instance providers", () => {
+			const baseUrl = "http://localhost:8080"
+
+			const litellmKey = getCacheKey({ provider: "litellm", apiKey: "key", baseUrl })
+			const ollamaKey = getCacheKey({ provider: "ollama", baseUrl })
+			const lmstudioKey = getCacheKey({ provider: "lmstudio", baseUrl })
+			const requestyKey = getCacheKey({ provider: "requesty", baseUrl })
+			const deepinfraKey = getCacheKey({ provider: "deepinfra", baseUrl })
+			const rooKey = getCacheKey({ provider: "roo", baseUrl })
+
+			// All should have hash suffix
+			expect(litellmKey).toMatch(/^litellm_[a-f0-9]{8}$/)
+			expect(ollamaKey).toMatch(/^ollama_[a-f0-9]{8}$/)
+			expect(lmstudioKey).toMatch(/^lmstudio_[a-f0-9]{8}$/)
+			expect(requestyKey).toMatch(/^requesty_[a-f0-9]{8}$/)
+			expect(deepinfraKey).toMatch(/^deepinfra_[a-f0-9]{8}$/)
+			expect(rooKey).toMatch(/^roo_[a-f0-9]{8}$/)
+		})
+	})
+
+	describe("getModels with different LiteLLM instances", () => {
+		let mockCache: any
+		let mockSet: Mock
+
+		beforeEach(() => {
+			vi.clearAllMocks()
+			const MockedNodeCache = vi.mocked(NodeCache)
+			mockCache = new MockedNodeCache()
+			mockSet = mockCache.set
+			mockCache.get.mockReturnValue(undefined)
+		})
+
+		it("uses unique cache keys for different LiteLLM base URLs", async () => {
+			const modelsInstance1 = {
+				"model-a": {
+					maxTokens: 4096,
+					contextWindow: 128000,
+					supportsPromptCache: false,
+					description: "Model A",
+				},
+			}
+
+			mockGetLiteLLMModels.mockResolvedValue(modelsInstance1)
+
+			await getModels({
+				provider: "litellm",
+				apiKey: "test-key",
+				baseUrl: "http://localhost:4000",
+			})
+
+			// Should cache with unique key including hash
+			expect(mockSet).toHaveBeenCalledWith(expect.stringMatching(/^litellm_[a-f0-9]{8}$/), modelsInstance1)
+		})
+
+		it("caches separately for different LiteLLM instances", async () => {
+			const modelsInstance1 = {
+				"model-instance-1": {
+					maxTokens: 4096,
+					contextWindow: 128000,
+					supportsPromptCache: false,
+					description: "Model from instance 1",
+				},
+			}
+			const modelsInstance2 = {
+				"model-instance-2": {
+					maxTokens: 8192,
+					contextWindow: 200000,
+					supportsPromptCache: true,
+					description: "Model from instance 2",
+				},
+			}
+
+			// First instance returns models
+			mockGetLiteLLMModels.mockResolvedValueOnce(modelsInstance1)
+			await getModels({
+				provider: "litellm",
+				apiKey: "test-key",
+				baseUrl: "http://localhost:4000",
+			})
+
+			// Second instance returns different models
+			mockGetLiteLLMModels.mockResolvedValueOnce(modelsInstance2)
+			await getModels({
+				provider: "litellm",
+				apiKey: "test-key",
+				baseUrl: "http://localhost:5000",
+			})
+
+			// Should have been called twice with different cache keys
+			expect(mockSet).toHaveBeenCalledTimes(2)
+
+			// Get the cache keys used
+			const cacheKey1 = mockSet.mock.calls[0][0]
+			const cacheKey2 = mockSet.mock.calls[1][0]
+
+			// Keys should be different
+			expect(cacheKey1).not.toBe(cacheKey2)
+
+			// Both should be litellm keys
+			expect(cacheKey1).toMatch(/^litellm_/)
+			expect(cacheKey2).toMatch(/^litellm_/)
+		})
+	})
+
+	describe("refreshModels with different instances", () => {
+		let mockCache: any
+		let mockSet: Mock
+
+		beforeEach(() => {
+			vi.clearAllMocks()
+			const MockedNodeCache = vi.mocked(NodeCache)
+			mockCache = new MockedNodeCache()
+			mockSet = mockCache.set
+			mockCache.get.mockReturnValue(undefined)
+		})
+
+		it("tracks in-flight requests separately for different instances", async () => {
+			const models1 = {
+				"model-1": {
+					maxTokens: 4096,
+					contextWindow: 128000,
+					supportsPromptCache: false,
+				},
+			}
+			const models2 = {
+				"model-2": {
+					maxTokens: 8192,
+					contextWindow: 200000,
+					supportsPromptCache: false,
+				},
+			}
+
+			// Create delayed responses to simulate API latency
+			let resolve1: (value: typeof models1) => void
+			let resolve2: (value: typeof models2) => void
+			const promise1 = new Promise((resolve) => {
+				resolve1 = resolve
+			})
+			const promise2 = new Promise((resolve) => {
+				resolve2 = resolve
+			})
+
+			mockGetLiteLLMModels
+				.mockReturnValueOnce(promise1)
+				.mockReturnValueOnce(promise2)
+
+			const { refreshModels } = await import("../modelCache")
+
+			// Start concurrent refreshes for different instances
+			const refresh1 = refreshModels({
+				provider: "litellm",
+				apiKey: "key",
+				baseUrl: "http://instance1:4000",
+			})
+			const refresh2 = refreshModels({
+				provider: "litellm",
+				apiKey: "key",
+				baseUrl: "http://instance2:5000",
+			})
+
+			// Both should call the API since they are different instances
+			expect(mockGetLiteLLMModels).toHaveBeenCalledTimes(2)
+
+			// Resolve both
+			resolve1!(models1)
+			resolve2!(models2)
+
+			const [result1, result2] = await Promise.all([refresh1, refresh2])
+
+			// Results should be different
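+			// (each refresh should resolve with its own instance's models, not a shared result)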
+			expect(result1).toEqual(models1)
+			expect(result2).toEqual(models2)
+		})
+	})
+})
diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts
index d22abf9c91c..dd52e321357 100644
--- a/src/api/providers/fetchers/modelCache.ts
+++ b/src/api/providers/fetchers/modelCache.ts
@@ -35,24 +35,106 @@ const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
 // Zod schema for validating ModelRecord structure from disk cache
 const modelRecordSchema = z.record(z.string(), modelInfoSchema)
 
-// Track in-flight refresh requests to prevent concurrent API calls for the same provider
+// Track in-flight refresh requests to prevent concurrent API calls for the same provider+instance
 // This prevents race conditions where multiple calls might overwrite each other's results
-const inFlightRefresh = new Map<RouterName, Promise<ModelRecord>>()
+// Uses unique cache keys that include base URL hash for multi-instance providers
+const inFlightRefresh = new Map<string, Promise<ModelRecord>>()
+
+// Providers that can have multiple instances with different base URLs
+// These need unique cache keys to prevent model mixing between configurations
+const multiInstanceProviders = new Set<string>([
+	"litellm",
+	"ollama",
+	"lmstudio",
+	"requesty",
+	"deepinfra",
+	"roo",
+])
 
-async function writeModels(router: RouterName, data: ModelRecord) {
-	const filename = `${router}_models.json`
+/**
+ * Simple hash function for generating short, deterministic hashes from strings.
+ * Uses djb2 algorithm which produces good distribution for URL-like strings.
+ * @param str - String to hash
+ * @returns 8-character hex hash
+ */
+function simpleHash(str: string): string {
+	let hash = 5381
+	for (let i = 0; i < str.length; i++) {
+		hash = (hash * 33) ^ str.charCodeAt(i)
+	}
+	// Convert to unsigned 32-bit integer and then to hex
+	return (hash >>> 0).toString(16).padStart(8, "0")
+}
+
+/**
+ * Generate a unique cache key for a provider based on its configuration.
+ * For multi-instance providers (LiteLLM, Ollama, etc.), includes a hash of the
+ * base URL to ensure different instances maintain separate caches.
+ *
+ * @param options - Provider options including provider name and optional baseUrl
+ * @returns Unique cache key string
+ */
+export function getCacheKey(options: GetModelsOptions): string {
+	const { provider } = options
+	const baseUrl = "baseUrl" in options ? options.baseUrl : undefined
+
+	// For multi-instance providers with a base URL, include a hash of the URL
+	if (multiInstanceProviders.has(provider) && baseUrl) {
+		const urlHash = simpleHash(baseUrl)
+		return `${provider}_${urlHash}`
+	}
+
+	return provider
+}
+
+/**
+ * Generate a unique filename for disk cache based on provider configuration.
+ * Uses the same unique key logic as getCacheKey.
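+ * e.g. "openrouter_models.json", or "litellm_<urlHash>_models.json" for a LiteLLM instance with a base URL.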
+ *
+ * @param options - Provider options including provider name and optional baseUrl
+ * @returns Filename for the cache file (without path)
+ */
+function getCacheFilename(options: GetModelsOptions): string {
+	return `${getCacheKey(options)}_models.json`
+}
+
+async function writeModels(options: GetModelsOptions, data: ModelRecord) {
+	const filename = getCacheFilename(options)
 	const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath)
 	await safeWriteJson(path.join(cacheDir, filename), data)
 }
 
-async function readModels(router: RouterName): Promise<ModelRecord | undefined> {
-	const filename = `${router}_models.json`
+async function readModels(options: GetModelsOptions): Promise<ModelRecord | undefined> {
+	const filename = getCacheFilename(options)
 	const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath)
 	const filePath = path.join(cacheDir, filename)
 	const exists = await fileExistsAtPath(filePath)
 	return exists ? JSON.parse(await fs.readFile(filePath, "utf8")) : undefined
 }
 
+/**
+ * Delete the disk cache file for a specific provider configuration.
+ * Used during flush to ensure both memory and disk caches are cleared.
+ *
+ * @param options - Provider options including provider name and optional baseUrl
+ */
+async function deleteModelsCacheFile(options: GetModelsOptions): Promise<void> {
+	const filename = getCacheFilename(options)
+	const cacheDir = getCacheDirectoryPathSync()
+	if (!cacheDir) {
+		return
+	}
+	const filePath = path.join(cacheDir, filename)
+	try {
+		await fs.unlink(filePath)
+	} catch (error) {
+		// File may not exist, which is fine
+		if ((error as NodeJS.ErrnoException).code !== "ENOENT") {
+			console.error(`[MODEL_CACHE] Error deleting cache file ${filename}:`, error)
+		}
+	}
+}
+
 /**
  * Fetch models from the provider API.
  * Extracted to avoid duplication between getModels() and refreshModels().
@@ -124,15 +206,17 @@ async function fetchModelsFromProvider(options: GetModelsOptions): Promise
  */
 export const getModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
 	const { provider } = options
+	const cacheKey = getCacheKey(options)
 
-	let models = getModelsFromCache(provider)
+	let models = getModelsFromCacheWithKey(cacheKey)
 
 	if (models) {
 		return models
@@ -145,10 +229,10 @@ export const getModels = async (options: GetModelsOptions): Promise
 		// Only cache non-empty results to prevent persisting failed API responses
 		// Empty results could indicate API failure rather than "no models exist"
 		if (modelCount > 0) {
-			memoryCache.set(provider, models)
+			memoryCache.set(cacheKey, models)
 
-			await writeModels(provider, models).catch((err) =>
-				console.error(`[MODEL_CACHE] Error writing ${provider} models to file cache:`, err),
+			await writeModels(options, models).catch((err) =>
+				console.error(`[MODEL_CACHE] Error writing ${cacheKey} models to file cache:`, err),
 			)
 		} else {
 			TelemetryService.instance.captureEvent(TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE, {
@@ -161,7 +245,7 @@ export const getModels = async (options: GetModelsOptions): Promise
 		return models
 	} catch (error) {
 		// Log the error and re-throw it so the caller can handle it (e.g., show a UI message).
-		console.error(`[getModels] Failed to fetch models in modelCache for ${provider}:`, error)
+		console.error(`[getModels] Failed to fetch models in modelCache for ${cacheKey}:`, error)
 
 		throw error // Re-throw the original error to be handled by the caller.
@@ -170,19 +254,23 @@ export const getModels = async (options: GetModelsOptions): Promise
 /**
  * Force-refresh models from API, bypassing cache.
  * Uses atomic writes so cache remains available during refresh.
- * This function also prevents concurrent API calls for the same provider using
+ * This function also prevents concurrent API calls for the same provider+instance using
  * in-flight request tracking to avoid race conditions.
  *
+ * For multi-instance providers (LiteLLM, Ollama, etc.), uses unique cache keys that
+ * include a hash of the base URL to ensure different instances maintain separate caches.
+ *
  * @param options - Provider options for fetching models
  * @returns Fresh models from API, or existing cache if refresh yields worse data
  */
 export const refreshModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
 	const { provider } = options
+	const cacheKey = getCacheKey(options)
 
-	// Check if there's already an in-flight refresh for this provider
+	// Check if there's already an in-flight refresh for this provider+instance
 	// This prevents race conditions where multiple concurrent refreshes might
 	// overwrite each other's results
-	const existingRequest = inFlightRefresh.get(provider)
+	const existingRequest = inFlightRefresh.get(cacheKey)
 	if (existingRequest) {
 		return existingRequest
 	}
@@ -190,12 +278,12 @@ export const refreshModels = async (options: GetModelsOptions): Promise
 	const refreshPromise = (async (): Promise<ModelRecord> => {
 		try {
-			// Force fresh API fetch - skip getModelsFromCache() check
+			// Force fresh API fetch - skip getModelsFromCacheWithKey() check
 			const models = await fetchModelsFromProvider(options)
 			const modelCount = Object.keys(models).length
 
 			// Get existing cached data for comparison
-			const existingCache = getModelsFromCache(provider)
+			const existingCache = getModelsFromCacheWithKey(cacheKey)
 			const existingCount = existingCache ? Object.keys(existingCache).length : 0
 
 			if (modelCount === 0) {
@@ -213,26 +301,26 @@ export const refreshModels = async (options: GetModelsOptions): Promise
-			await writeModels(provider, models).catch((err) =>
-				console.error(`[refreshModels] Error writing ${provider} models to disk:`, err),
+			await writeModels(options, models).catch((err) =>
+				console.error(`[refreshModels] Error writing ${cacheKey} models to disk:`, err),
 			)
 
 			return models
 		} catch (error) {
 			// Log the error for debugging, then return existing cache if available (graceful degradation)
-			console.error(`[refreshModels] Failed to refresh ${provider} models:`, error)
-			return getModelsFromCache(provider) || {}
+			console.error(`[refreshModels] Failed to refresh ${cacheKey} models:`, error)
+			return getModelsFromCacheWithKey(cacheKey) || {}
 		} finally {
 			// Always clean up the in-flight tracking
-			inFlightRefresh.delete(provider)
+			inFlightRefresh.delete(cacheKey)
 		}
 	})()
 
 	// Track the in-flight request
-	inFlightRefresh.set(provider, refreshPromise)
+	inFlightRefresh.set(cacheKey, refreshPromise)
 
 	return refreshPromise
 }
@@ -265,13 +353,14 @@ export async function initializeModelCacheRefresh(): Promise {
 }
 
 /**
- * Flush models memory cache for a specific router.
+ * Flush models cache for a specific provider configuration.
+ * Clears both memory cache and disk cache for the unique cache key.
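+ * When refresh is true, the memory entry is not deleted up front; refreshModels replaces it atomically.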
  *
  * @param options - The options for fetching models, including provider, apiKey, and baseUrl
  * @param refresh - If true, immediately fetch fresh data from API
  */
 export const flushModels = async (options: GetModelsOptions, refresh: boolean = false): Promise<void> => {
-	const { provider } = options
+	const cacheKey = getCacheKey(options)
 	if (refresh) {
 		// Don't delete memory cache - let refreshModels atomically replace it
 		// This prevents a race condition where getModels() might be called
@@ -279,22 +368,25 @@ export const flushModels = async (options: GetModelsOptions, refresh: boolean =
 		// Await the refresh to ensure the cache is updated before returning
 		await refreshModels(options)
 	} else {
-		// Only delete memory cache when not refreshing
-		memoryCache.del(provider)
+		// Delete both memory and file cache when not refreshing
+		memoryCache.del(cacheKey)
+		await deleteModelsCacheFile(options).catch((err) =>
+			console.error(`[flushModels] Error deleting ${cacheKey} cache file:`, err),
+		)
 	}
 }
 
 /**
- * Get models from cache, checking memory first, then disk.
+ * Get models from cache using a unique cache key, checking memory first, then disk.
  * This ensures providers always have access to last known good data,
  * preventing fallback to hardcoded defaults on startup.
  *
- * @param provider - The provider to get models for.
+ * @param cacheKey - The unique cache key (provider or provider_hash for multi-instance)
  * @returns Models from memory cache, disk cache, or undefined if not cached.
  */
-export function getModelsFromCache(provider: ProviderName): ModelRecord | undefined {
+function getModelsFromCacheWithKey(cacheKey: string): ModelRecord | undefined {
 	// Check memory cache first (fast)
-	const memoryModels = memoryCache.get<ModelRecord>(provider)
+	const memoryModels = memoryCache.get<ModelRecord>(cacheKey)
 	if (memoryModels) {
 		return memoryModels
 	}
@@ -302,7 +394,7 @@ export function getModelsFromCache(provider: ProviderName): ModelRecord | undefi
 	// Memory cache miss - try to load from disk synchronously
 	// This is acceptable because it only happens on cold start or after cache expiry
 	try {
-		const filename = `${provider}_models.json`
+		const filename = `${cacheKey}_models.json`
 		const cacheDir = getCacheDirectoryPathSync()
 		if (!cacheDir) {
 			return undefined
@@ -320,24 +412,36 @@ export function getModelsFromCache(provider: ProviderName): ModelRecord | undefi
 			const validation = modelRecordSchema.safeParse(models)
 			if (!validation.success) {
 				console.error(
-					`[MODEL_CACHE] Invalid disk cache data structure for ${provider}:`,
+					`[MODEL_CACHE] Invalid disk cache data structure for ${cacheKey}:`,
 					validation.error.format(),
 				)
 				return undefined
 			}
 
 			// Populate memory cache for future fast access
-			memoryCache.set(provider, validation.data)
+			memoryCache.set(cacheKey, validation.data)
 			return validation.data
 		}
 	} catch (error) {
-		console.error(`[MODEL_CACHE] Error loading ${provider} models from disk:`, error)
+		console.error(`[MODEL_CACHE] Error loading ${cacheKey} models from disk:`, error)
 	}
 
 	return undefined
 }
 
+/**
+ * Get models from cache for a provider.
+ * This is a convenience wrapper that uses the provider name as the cache key.
+ * For multi-instance providers, use getModels() instead which handles unique keys.
+ *
+ * @param provider - The provider to get models for.
+ * @returns Models from memory cache, disk cache, or undefined if not cached.
+ */
+export function getModelsFromCache(provider: ProviderName): ModelRecord | undefined {
+	return getModelsFromCacheWithKey(provider)
+}
+
 /**
  * Synchronous version of getCacheDirectoryPath for use in getModelsFromCache.
  * Returns the cache directory path without async operations.
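
For reviewers who want to sanity-check the key scheme outside the extension, the following is a minimal standalone TypeScript sketch of the hashing and key construction this patch introduces. It is illustrative only: ProviderOptions, MULTI_INSTANCE, djb2Hex, and cacheKeyFor are simplified stand-ins for the real GetModelsOptions, multiInstanceProviders, simpleHash, and getCacheKey, and the hash suffixes in the trailing comments are placeholders rather than computed values.

// Standalone sketch of the cache-key scheme (not part of the patch).
type ProviderOptions = { provider: string; baseUrl?: string }

const MULTI_INSTANCE = new Set(["litellm", "ollama", "lmstudio", "requesty", "deepinfra", "roo"])

// djb2-xor hash, mirroring simpleHash() above: a deterministic 8-character hex suffix per base URL.
function djb2Hex(str: string): string {
	let hash = 5381
	for (let i = 0; i < str.length; i++) {
		hash = (hash * 33) ^ str.charCodeAt(i)
	}
	return (hash >>> 0).toString(16).padStart(8, "0")
}

// Mirrors getCacheKey(): multi-instance providers with a base URL get "<provider>_<hash>".
function cacheKeyFor({ provider, baseUrl }: ProviderOptions): string {
	return MULTI_INSTANCE.has(provider) && baseUrl ? `${provider}_${djb2Hex(baseUrl)}` : provider
}

// Two LiteLLM instances map to distinct keys (and therefore to distinct memory/disk cache entries),
// while a single-instance provider keeps the bare provider name as its key.
console.log(cacheKeyFor({ provider: "litellm", baseUrl: "http://localhost:4000" })) // litellm_<hash A>
console.log(cacheKeyFor({ provider: "litellm", baseUrl: "http://localhost:5000" })) // litellm_<hash B>, differs from A
console.log(cacheKeyFor({ provider: "openrouter" })) // openrouter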