From f943e34550d29562206e9c540c71afa5c23ad1f4 Mon Sep 17 00:00:00 2001 From: Roo Code Date: Tue, 6 Jan 2026 22:51:10 +0000 Subject: [PATCH] fix: prevent model mixing between LiteLLM configuration profiles This commit fixes an issue where different LiteLLM instances with different base URLs would share the same model cache, causing model mixing when switching between configuration profiles. Changes: - Added getCacheKey() function that generates unique cache keys including a hash of the base URL for multi-instance providers (LiteLLM, Ollama, LM Studio, Requesty, DeepInfra, Roo) - Updated writeModels/readModels to use unique filenames per instance - Updated getModels/refreshModels to use unique cache keys - Updated flushModels to clear both memory and file cache for the instance - Added deleteModelsCacheFile() for cleaning up disk cache - Added comprehensive tests for multi-instance caching behavior Fixes #9169 --- .../fetchers/__tests__/modelCache.spec.ts | 226 +++++++++++++++++- src/api/providers/fetchers/modelCache.ts | 180 +++++++++++--- 2 files changed, 367 insertions(+), 39 deletions(-) diff --git a/src/api/providers/fetchers/__tests__/modelCache.spec.ts b/src/api/providers/fetchers/__tests__/modelCache.spec.ts index 3c73b2a2725..1b95da39aac 100644 --- a/src/api/providers/fetchers/__tests__/modelCache.spec.ts +++ b/src/api/providers/fetchers/__tests__/modelCache.spec.ts @@ -59,7 +59,7 @@ vi.mock("../../../core/config/ContextProxy", () => ({ import type { Mock } from "vitest" import * as fsSync from "fs" import NodeCache from "node-cache" -import { getModels, getModelsFromCache } from "../modelCache" +import { getModels, getModelsFromCache, getCacheKey } from "../modelCache" import { getLiteLLMModels } from "../litellm" import { getOpenRouterModels } from "../openrouter" import { getRequestyModels } from "../requesty" @@ -474,3 +474,227 @@ describe("empty cache protection", () => { }) }) }) + +describe("multi-instance provider caching", () => { + describe("getCacheKey", () => { + it("returns provider name for non-multi-instance providers", () => { + expect(getCacheKey({ provider: "openrouter" })).toBe("openrouter") + expect(getCacheKey({ provider: "huggingface" })).toBe("huggingface") + expect(getCacheKey({ provider: "vercel-ai-gateway" })).toBe("vercel-ai-gateway") + }) + + it("returns provider name for multi-instance providers without baseUrl", () => { + expect(getCacheKey({ provider: "litellm", apiKey: "test-key", baseUrl: "" })).toBe("litellm") + expect(getCacheKey({ provider: "ollama" })).toBe("ollama") + expect(getCacheKey({ provider: "lmstudio" })).toBe("lmstudio") + }) + + it("returns unique key for multi-instance providers with baseUrl", () => { + const key1 = getCacheKey({ + provider: "litellm", + apiKey: "test-key", + baseUrl: "http://localhost:4000", + }) + const key2 = getCacheKey({ + provider: "litellm", + apiKey: "test-key", + baseUrl: "http://localhost:5000", + }) + + // Keys should be different for different URLs + expect(key1).not.toBe(key2) + // Keys should start with provider name + expect(key1).toMatch(/^litellm_/) + expect(key2).toMatch(/^litellm_/) + }) + + it("returns same key for same provider and baseUrl", () => { + const key1 = getCacheKey({ + provider: "litellm", + apiKey: "key1", + baseUrl: "http://localhost:4000", + }) + const key2 = getCacheKey({ + provider: "litellm", + apiKey: "different-key", // Different API key should not affect cache key + baseUrl: "http://localhost:4000", + }) + + expect(key1).toBe(key2) + }) + + it("generates unique keys for all multi-instance providers", () => { + const baseUrl = "http://localhost:8080" + + const litellmKey = getCacheKey({ provider: "litellm", apiKey: "key", baseUrl }) + const ollamaKey = getCacheKey({ provider: "ollama", baseUrl }) + const lmstudioKey = getCacheKey({ provider: "lmstudio", baseUrl }) + const requestyKey = getCacheKey({ provider: "requesty", baseUrl }) + const deepinfraKey = getCacheKey({ provider: "deepinfra", baseUrl }) + const rooKey = getCacheKey({ provider: "roo", baseUrl }) + + // All should have hash suffix + expect(litellmKey).toMatch(/^litellm_[a-f0-9]{8}$/) + expect(ollamaKey).toMatch(/^ollama_[a-f0-9]{8}$/) + expect(lmstudioKey).toMatch(/^lmstudio_[a-f0-9]{8}$/) + expect(requestyKey).toMatch(/^requesty_[a-f0-9]{8}$/) + expect(deepinfraKey).toMatch(/^deepinfra_[a-f0-9]{8}$/) + expect(rooKey).toMatch(/^roo_[a-f0-9]{8}$/) + }) + }) + + describe("getModels with different LiteLLM instances", () => { + let mockCache: any + let mockSet: Mock + + beforeEach(() => { + vi.clearAllMocks() + const MockedNodeCache = vi.mocked(NodeCache) + mockCache = new MockedNodeCache() + mockSet = mockCache.set + mockCache.get.mockReturnValue(undefined) + }) + + it("uses unique cache keys for different LiteLLM base URLs", async () => { + const modelsInstance1 = { + "model-a": { + maxTokens: 4096, + contextWindow: 128000, + supportsPromptCache: false, + description: "Model A", + }, + } + + mockGetLiteLLMModels.mockResolvedValue(modelsInstance1) + + await getModels({ + provider: "litellm", + apiKey: "test-key", + baseUrl: "http://localhost:4000", + }) + + // Should cache with unique key including hash + expect(mockSet).toHaveBeenCalledWith(expect.stringMatching(/^litellm_[a-f0-9]{8}$/), modelsInstance1) + }) + + it("caches separately for different LiteLLM instances", async () => { + const modelsInstance1 = { + "model-instance-1": { + maxTokens: 4096, + contextWindow: 128000, + supportsPromptCache: false, + description: "Model from instance 1", + }, + } + const modelsInstance2 = { + "model-instance-2": { + maxTokens: 8192, + contextWindow: 200000, + supportsPromptCache: true, + description: "Model from instance 2", + }, + } + + // First instance returns models + mockGetLiteLLMModels.mockResolvedValueOnce(modelsInstance1) + await getModels({ + provider: "litellm", + apiKey: "test-key", + baseUrl: "http://localhost:4000", + }) + + // Second instance returns different models + mockGetLiteLLMModels.mockResolvedValueOnce(modelsInstance2) + await getModels({ + provider: "litellm", + apiKey: "test-key", + baseUrl: "http://localhost:5000", + }) + + // Should have been called twice with different cache keys + expect(mockSet).toHaveBeenCalledTimes(2) + + // Get the cache keys used + const cacheKey1 = mockSet.mock.calls[0][0] + const cacheKey2 = mockSet.mock.calls[1][0] + + // Keys should be different + expect(cacheKey1).not.toBe(cacheKey2) + + // Both should be litellm keys + expect(cacheKey1).toMatch(/^litellm_/) + expect(cacheKey2).toMatch(/^litellm_/) + }) + }) + + describe("refreshModels with different instances", () => { + let mockCache: any + let mockSet: Mock + + beforeEach(() => { + vi.clearAllMocks() + const MockedNodeCache = vi.mocked(NodeCache) + mockCache = new MockedNodeCache() + mockSet = mockCache.set + mockCache.get.mockReturnValue(undefined) + }) + + it("tracks in-flight requests separately for different instances", async () => { + const models1 = { + "model-1": { + maxTokens: 4096, + contextWindow: 128000, + supportsPromptCache: false, + }, + } + const models2 = { + "model-2": { + maxTokens: 8192, + contextWindow: 200000, + supportsPromptCache: false, + }, + } + + // Create delayed responses to simulate API latency + let resolve1: (value: typeof models1) => void + let resolve2: (value: typeof models2) => void + const promise1 = new Promise((resolve) => { + resolve1 = resolve + }) + const promise2 = new Promise((resolve) => { + resolve2 = resolve + }) + + mockGetLiteLLMModels + .mockReturnValueOnce(promise1) + .mockReturnValueOnce(promise2) + + const { refreshModels } = await import("../modelCache") + + // Start concurrent refreshes for different instances + const refresh1 = refreshModels({ + provider: "litellm", + apiKey: "key", + baseUrl: "http://instance1:4000", + }) + const refresh2 = refreshModels({ + provider: "litellm", + apiKey: "key", + baseUrl: "http://instance2:5000", + }) + + // Both should call the API since they are different instances + expect(mockGetLiteLLMModels).toHaveBeenCalledTimes(2) + + // Resolve both + resolve1!(models1) + resolve2!(models2) + + const [result1, result2] = await Promise.all([refresh1, refresh2]) + + // Results should be different + expect(result1).toEqual(models1) + expect(result2).toEqual(models2) + }) + }) +}) diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index d22abf9c91c..dd52e321357 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -35,24 +35,106 @@ const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) // Zod schema for validating ModelRecord structure from disk cache const modelRecordSchema = z.record(z.string(), modelInfoSchema) -// Track in-flight refresh requests to prevent concurrent API calls for the same provider +// Track in-flight refresh requests to prevent concurrent API calls for the same provider+instance // This prevents race conditions where multiple calls might overwrite each other's results -const inFlightRefresh = new Map>() +// Uses unique cache keys that include base URL hash for multi-instance providers +const inFlightRefresh = new Map>() + +// Providers that can have multiple instances with different base URLs +// These need unique cache keys to prevent model mixing between configurations +const multiInstanceProviders = new Set([ + "litellm", + "ollama", + "lmstudio", + "requesty", + "deepinfra", + "roo", +]) -async function writeModels(router: RouterName, data: ModelRecord) { - const filename = `${router}_models.json` +/** + * Simple hash function for generating short, deterministic hashes from strings. + * Uses djb2 algorithm which produces good distribution for URL-like strings. + * @param str - String to hash + * @returns 8-character hex hash + */ +function simpleHash(str: string): string { + let hash = 5381 + for (let i = 0; i < str.length; i++) { + hash = (hash * 33) ^ str.charCodeAt(i) + } + // Convert to unsigned 32-bit integer and then to hex + return (hash >>> 0).toString(16).padStart(8, "0") +} + +/** + * Generate a unique cache key for a provider based on its configuration. + * For multi-instance providers (LiteLLM, Ollama, etc.), includes a hash of the + * base URL to ensure different instances maintain separate caches. + * + * @param options - Provider options including provider name and optional baseUrl + * @returns Unique cache key string + */ +export function getCacheKey(options: GetModelsOptions): string { + const { provider } = options + const baseUrl = "baseUrl" in options ? options.baseUrl : undefined + + // For multi-instance providers with a base URL, include a hash of the URL + if (multiInstanceProviders.has(provider) && baseUrl) { + const urlHash = simpleHash(baseUrl) + return `${provider}_${urlHash}` + } + + return provider +} + +/** + * Generate a unique filename for disk cache based on provider configuration. + * Uses the same unique key logic as getCacheKey. + * + * @param options - Provider options including provider name and optional baseUrl + * @returns Filename for the cache file (without path) + */ +function getCacheFilename(options: GetModelsOptions): string { + return `${getCacheKey(options)}_models.json` +} + +async function writeModels(options: GetModelsOptions, data: ModelRecord) { + const filename = getCacheFilename(options) const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath) await safeWriteJson(path.join(cacheDir, filename), data) } -async function readModels(router: RouterName): Promise { - const filename = `${router}_models.json` +async function readModels(options: GetModelsOptions): Promise { + const filename = getCacheFilename(options) const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath) const filePath = path.join(cacheDir, filename) const exists = await fileExistsAtPath(filePath) return exists ? JSON.parse(await fs.readFile(filePath, "utf8")) : undefined } +/** + * Delete the disk cache file for a specific provider configuration. + * Used during flush to ensure both memory and disk caches are cleared. + * + * @param options - Provider options including provider name and optional baseUrl + */ +async function deleteModelsCacheFile(options: GetModelsOptions): Promise { + const filename = getCacheFilename(options) + const cacheDir = getCacheDirectoryPathSync() + if (!cacheDir) { + return + } + const filePath = path.join(cacheDir, filename) + try { + await fs.unlink(filePath) + } catch (error) { + // File may not exist, which is fine + if ((error as NodeJS.ErrnoException).code !== "ENOENT") { + console.error(`[MODEL_CACHE] Error deleting cache file ${filename}:`, error) + } + } +} + /** * Fetch models from the provider API. * Extracted to avoid duplication between getModels() and refreshModels(). @@ -124,15 +206,17 @@ async function fetchModelsFromProvider(options: GetModelsOptions): Promise => { const { provider } = options + const cacheKey = getCacheKey(options) - let models = getModelsFromCache(provider) + let models = getModelsFromCacheWithKey(cacheKey) if (models) { return models @@ -145,10 +229,10 @@ export const getModels = async (options: GetModelsOptions): Promise // Only cache non-empty results to prevent persisting failed API responses // Empty results could indicate API failure rather than "no models exist" if (modelCount > 0) { - memoryCache.set(provider, models) + memoryCache.set(cacheKey, models) - await writeModels(provider, models).catch((err) => - console.error(`[MODEL_CACHE] Error writing ${provider} models to file cache:`, err), + await writeModels(options, models).catch((err) => + console.error(`[MODEL_CACHE] Error writing ${cacheKey} models to file cache:`, err), ) } else { TelemetryService.instance.captureEvent(TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE, { @@ -161,7 +245,7 @@ export const getModels = async (options: GetModelsOptions): Promise return models } catch (error) { // Log the error and re-throw it so the caller can handle it (e.g., show a UI message). - console.error(`[getModels] Failed to fetch models in modelCache for ${provider}:`, error) + console.error(`[getModels] Failed to fetch models in modelCache for ${cacheKey}:`, error) throw error // Re-throw the original error to be handled by the caller. } @@ -170,19 +254,23 @@ export const getModels = async (options: GetModelsOptions): Promise /** * Force-refresh models from API, bypassing cache. * Uses atomic writes so cache remains available during refresh. - * This function also prevents concurrent API calls for the same provider using + * This function also prevents concurrent API calls for the same provider+instance using * in-flight request tracking to avoid race conditions. * + * For multi-instance providers (LiteLLM, Ollama, etc.), uses unique cache keys that + * include a hash of the base URL to ensure different instances maintain separate caches. + * * @param options - Provider options for fetching models * @returns Fresh models from API, or existing cache if refresh yields worse data */ export const refreshModels = async (options: GetModelsOptions): Promise => { const { provider } = options + const cacheKey = getCacheKey(options) - // Check if there's already an in-flight refresh for this provider + // Check if there's already an in-flight refresh for this provider+instance // This prevents race conditions where multiple concurrent refreshes might // overwrite each other's results - const existingRequest = inFlightRefresh.get(provider) + const existingRequest = inFlightRefresh.get(cacheKey) if (existingRequest) { return existingRequest } @@ -190,12 +278,12 @@ export const refreshModels = async (options: GetModelsOptions): Promise => { try { - // Force fresh API fetch - skip getModelsFromCache() check + // Force fresh API fetch - skip getModelsFromCacheWithKey() check const models = await fetchModelsFromProvider(options) const modelCount = Object.keys(models).length // Get existing cached data for comparison - const existingCache = getModelsFromCache(provider) + const existingCache = getModelsFromCacheWithKey(cacheKey) const existingCount = existingCache ? Object.keys(existingCache).length : 0 if (modelCount === 0) { @@ -213,26 +301,26 @@ export const refreshModels = async (options: GetModelsOptions): Promise - console.error(`[refreshModels] Error writing ${provider} models to disk:`, err), + await writeModels(options, models).catch((err) => + console.error(`[refreshModels] Error writing ${cacheKey} models to disk:`, err), ) return models } catch (error) { // Log the error for debugging, then return existing cache if available (graceful degradation) - console.error(`[refreshModels] Failed to refresh ${provider} models:`, error) - return getModelsFromCache(provider) || {} + console.error(`[refreshModels] Failed to refresh ${cacheKey} models:`, error) + return getModelsFromCacheWithKey(cacheKey) || {} } finally { // Always clean up the in-flight tracking - inFlightRefresh.delete(provider) + inFlightRefresh.delete(cacheKey) } })() // Track the in-flight request - inFlightRefresh.set(provider, refreshPromise) + inFlightRefresh.set(cacheKey, refreshPromise) return refreshPromise } @@ -265,13 +353,14 @@ export async function initializeModelCacheRefresh(): Promise { } /** - * Flush models memory cache for a specific router. + * Flush models cache for a specific provider configuration. + * Clears both memory cache and disk cache for the unique cache key. * * @param options - The options for fetching models, including provider, apiKey, and baseUrl * @param refresh - If true, immediately fetch fresh data from API */ export const flushModels = async (options: GetModelsOptions, refresh: boolean = false): Promise => { - const { provider } = options + const cacheKey = getCacheKey(options) if (refresh) { // Don't delete memory cache - let refreshModels atomically replace it // This prevents a race condition where getModels() might be called @@ -279,22 +368,25 @@ export const flushModels = async (options: GetModelsOptions, refresh: boolean = // Await the refresh to ensure the cache is updated before returning await refreshModels(options) } else { - // Only delete memory cache when not refreshing - memoryCache.del(provider) + // Delete both memory and file cache when not refreshing + memoryCache.del(cacheKey) + await deleteModelsCacheFile(options).catch((err) => + console.error(`[flushModels] Error deleting ${cacheKey} cache file:`, err), + ) } } /** - * Get models from cache, checking memory first, then disk. + * Get models from cache using a unique cache key, checking memory first, then disk. * This ensures providers always have access to last known good data, * preventing fallback to hardcoded defaults on startup. * - * @param provider - The provider to get models for. + * @param cacheKey - The unique cache key (provider or provider_hash for multi-instance) * @returns Models from memory cache, disk cache, or undefined if not cached. */ -export function getModelsFromCache(provider: ProviderName): ModelRecord | undefined { +function getModelsFromCacheWithKey(cacheKey: string): ModelRecord | undefined { // Check memory cache first (fast) - const memoryModels = memoryCache.get(provider) + const memoryModels = memoryCache.get(cacheKey) if (memoryModels) { return memoryModels } @@ -302,7 +394,7 @@ export function getModelsFromCache(provider: ProviderName): ModelRecord | undefi // Memory cache miss - try to load from disk synchronously // This is acceptable because it only happens on cold start or after cache expiry try { - const filename = `${provider}_models.json` + const filename = `${cacheKey}_models.json` const cacheDir = getCacheDirectoryPathSync() if (!cacheDir) { return undefined @@ -320,24 +412,36 @@ export function getModelsFromCache(provider: ProviderName): ModelRecord | undefi const validation = modelRecordSchema.safeParse(models) if (!validation.success) { console.error( - `[MODEL_CACHE] Invalid disk cache data structure for ${provider}:`, + `[MODEL_CACHE] Invalid disk cache data structure for ${cacheKey}:`, validation.error.format(), ) return undefined } // Populate memory cache for future fast access - memoryCache.set(provider, validation.data) + memoryCache.set(cacheKey, validation.data) return validation.data } } catch (error) { - console.error(`[MODEL_CACHE] Error loading ${provider} models from disk:`, error) + console.error(`[MODEL_CACHE] Error loading ${cacheKey} models from disk:`, error) } return undefined } +/** + * Get models from cache for a provider. + * This is a convenience wrapper that uses the provider name as the cache key. + * For multi-instance providers, use getModels() instead which handles unique keys. + * + * @param provider - The provider to get models for. + * @returns Models from memory cache, disk cache, or undefined if not cached. + */ +export function getModelsFromCache(provider: ProviderName): ModelRecord | undefined { + return getModelsFromCacheWithKey(provider) +} + /** * Synchronous version of getCacheDirectoryPath for use in getModelsFromCache. * Returns the cache directory path without async operations.