
Commit c6eea98 (parent: 655ce2e)

fix BerriAI#9692. Keep cache key stable during mutation

A) Return a copy from strict-key removal, so that removing keys does not break cache keys.
B) Fix an issue in the existing cache-key stabilizer, which did not store a stable key across request/response when the request had no litellm_params.
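For context, a minimal sketch of the failure mode being fixed (the hashing below is an illustration, not litellm's actual key derivation): if the cache key is recomputed from the request kwargs at response time, any in-place mutation of those kwargs in between yields a different key than the one the entry was stored under.

import hashlib
import json

def naive_cache_key(**kwargs) -> str:
    # derive the key by hashing the serialized kwargs; any mutation changes the digest
    return hashlib.sha256(json.dumps(kwargs, sort_keys=True).encode()).hexdigest()

kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, world!"}]}
key_at_request_time = naive_cache_key(**kwargs)

# e.g. a provider transform edits the messages in place while the request is handled
kwargs["messages"][0]["content"] = "mutated"
key_at_response_time = naive_cache_key(**kwargs)

assert key_at_request_time != key_at_response_time  # the stored entry can no longer be found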

5 files changed: +150 -10 lines changed
litellm/caching/caching.py (+6, -4)

@@ -330,10 +330,9 @@ def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
         """
         Get the preset cache key from kwargs["litellm_params"]
 
-        We use _get_preset_cache_keys for two reasons
+        Set after the cache key is first calculated, so it does not change between request and response time
+        in case the implementation mutates the original objects (this also avoids duplicate key calculations)
 
-        1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
-        2. avoid doing duplicate / repeated work
         """
         if kwargs:
             if "litellm_params" in kwargs:

@@ -346,7 +345,10 @@ def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
 
         This is used to avoid doing duplicate / repeated work
 
-        Placed in kwargs["litellm_params"]
+        Placed in kwargs["litellm_params"].
+
+        Note! The request must already have a `litellm_params` key for this to work,
+        since mutating the **kwargs splat object here does not mutate the caller's original object.
         """
         if kwargs:
             if "litellm_params" in kwargs:

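To make the mechanism these docstrings describe concrete, here is a rough sketch of the preset-key pattern (names simplified; this is not the actual Cache implementation): the first computed key is stashed in kwargs["litellm_params"], and later lookups return the stored value, so in-place mutation of the request no longer changes the key between request and response time.

from typing import Any, Callable, Dict, Optional

def get_or_set_preset_cache_key(compute_key: Callable[..., str], **kwargs: Any) -> str:
    litellm_params: Optional[Dict[str, Any]] = kwargs.get("litellm_params")
    preset = (litellm_params or {}).get("preset_cache_key")
    if preset is not None:
        return preset  # reuse the key computed at request time
    key = compute_key(**kwargs)
    if litellm_params is not None:
        # storing the key only works when the caller passed a litellm_params dict:
        # mutating the **kwargs splat itself is invisible to the caller, which is
        # why wrapper_async (below) now seeds kwargs["litellm_params"] = {}
        litellm_params["preset_cache_key"] = key
    return key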
litellm/utils.py (+20, -6)

@@ -1251,6 +1251,10 @@ async def wrapper_async(*args, **kwargs): # noqa: PLR0915
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
 
+        # set up litellm_params, so that keys can be added (e.g. for tracking cache keys)
+        if "litellm_params" not in kwargs:
+            kwargs["litellm_params"] = {}
+
         model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None)
         is_completion_with_fallbacks = kwargs.get("fallbacks") is not None
 

@@ -2770,23 +2774,33 @@ def _remove_additional_properties(schema):
 
 def _remove_strict_from_schema(schema):
     """
-    Relevant Issues: https://github.com/BerriAI/litellm/issues/6136, https://github.com/BerriAI/litellm/issues/6088
+    Recursively removes 'strict' from schema. Returns a copy, so as not to break cache keys (callers should use the return value).
+
+    Relevant Issues: https://github.com/BerriAI/litellm/issues/6136, https://github.com/BerriAI/litellm/issues/6088
     """
+    maybe_copy = None  # make a copy to not break cache keys https://github.com/BerriAI/litellm/issues/9692
    if isinstance(schema, dict):
         # Remove the 'additionalProperties' key if it exists and is set to False
         if "strict" in schema:
-            del schema["strict"]
+            maybe_copy = schema.copy()
+            del maybe_copy["strict"]
 
         # Recursively process all dictionary values
         for key, value in schema.items():
-            _remove_strict_from_schema(value)
+            result = _remove_strict_from_schema(value)
+            if result is not value:
+                maybe_copy = maybe_copy or schema.copy()
+                maybe_copy[key] = result
 
     elif isinstance(schema, list):
         # Recursively process all items in the list
-        for item in schema:
-            _remove_strict_from_schema(item)
+        for i, item in enumerate(schema):
+            result = _remove_strict_from_schema(item)
+            if result is not item:
+                maybe_copy = maybe_copy or list(schema)
+                maybe_copy[i] = result
 
-    return schema
+    return maybe_copy or schema
 
 
 def _remove_unsupported_params(

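A quick demonstration of the copy-on-write semantics above (assuming a litellm checkout that includes this commit): the input schema is left untouched, and subtrees that contained no 'strict' key are shared with the copy rather than duplicated.

from litellm.utils import _remove_strict_from_schema

schema = {
    "type": "object",
    "strict": True,
    "properties": {"name": {"type": "string"}},
}
cleaned = _remove_strict_from_schema(schema)

assert "strict" in schema       # the original is not mutated...
assert "strict" not in cleaned  # ...while the returned copy has the key removed
# the untouched "properties" subtree is shared, not copied
assert cleaned["properties"] is schema["properties"]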
tests/litellm_utils_tests/test_utils.py (+105)

@@ -1,4 +1,5 @@
 import copy
+import json
 import sys
 import time
 from datetime import datetime

@@ -1890,6 +1891,43 @@ async def test_function(**kwargs):
         == "gpt-4o-mini"
     )
 
+@pytest.mark.asyncio
+async def test_cache_key_stability_with_mutation(monkeypatch):
+    from litellm.utils import client
+    import asyncio
+    from litellm.caching import Cache
+
+    # Set up in-memory cache
+    cache = Cache()
+    monkeypatch.setattr(litellm, "cache", cache)
+
+    # Create mock original function
+    mock_original = AsyncMock()
+
+    def side_effect(**kwargs):
+        print(f"kwargs: {kwargs}")
+        return litellm.ModelResponse(
+            model="vertex_ai/gemini-2.0-flash"
+        )
+    mock_original.side_effect = side_effect
+
+    # Apply decorator
+    @client
+    async def acompletion(**kwargs):
+        kwargs["messages"][0]["content"] = "mutated"
+        return await mock_original(**kwargs)
+
+    # Test kwargs
+    test_kwargs = {"model": "vertex_ai/gemini-2.0-flash", "messages": [{"role": "user", "content": "Hello, world!"}]}
+    original_kwargs = copy.deepcopy(test_kwargs)
+
+    # Call decorated function; the second, unmutated call must hit the cache
+    await acompletion(**test_kwargs)
+    await asyncio.sleep(0.01)
+    await acompletion(**original_kwargs)
+
+    mock_original.assert_called_once()
+
 
 def test_dict_to_response_format_helper():
     from litellm.llms.base_llm.base_utils import _dict_to_response_format_helper

@@ -2102,3 +2140,70 @@ def test_get_provider_audio_transcription_config():
     config = ProviderConfigManager.get_provider_audio_transcription_config(
         model="whisper-1", provider=provider
     )
+
+def test_remove_strict_from_schema():
+    from litellm.utils import _remove_strict_from_schema
+
+    schema = {  # not necessarily a realistic JSON schema, just one full of 'strict' keys
+        "$schema": "http://json-schema.org/draft-07/schema#",
+        "type": "object",
+        "strict": True,
+        "definitions": {
+            "address": {
+                "type": "object",
+                "properties": {
+                    "street": {"type": "string"},
+                    "city": {"type": "string"}
+                },
+                "required": ["street", "city"],
+                "strict": True
+            }
+        },
+        "properties": {
+            "name": {
+                "type": "string",
+                "strict": True
+            },
+            "age": {
+                "type": "integer"
+            },
+            "address": {
+                "$ref": "#/definitions/address"
+            },
+            "tags": {
+                "type": "array",
+                "items": {"type": "string"},
+                "strict": True
+            },
+            "contacts": {
+                "type": "array",
+                "items": {
+                    "oneOf": [
+                        {"type": "string"},
+                        {
+                            "type": "array",
+                            "items": {
+                                "type": "object",
+                                "strict": True,
+                                "properties": {
+                                    "value": {"type": "string"}
+                                },
+                                "required": ["value"]
+                            }
+                        }
+                    ],
+                    "strict": True
+                }
+            }
+        }
+    }
+    original_schema = copy.deepcopy(schema)
+    cleaned = _remove_strict_from_schema(schema)
+    assert "strict" not in json.dumps(cleaned)
+    # schema should be unchanged (copy instead of mutate),
+    # otherwise cache keys break
+    # https://github.com/BerriAI/litellm/issues/9692
+    assert cleaned != original_schema
+    assert schema == original_schema
+
+

tests/local_testing/test_caching.py (+1)

@@ -2608,3 +2608,4 @@ def test_caching_with_reasoning_content():
     print(f"response 2: {response_2.model_dump_json(indent=4)}")
     assert response_2._hidden_params["cache_hit"] == True
     assert response_2.choices[0].message.reasoning_content is not None
+

tests/local_testing/test_unit_test_caching.py (+18)

@@ -251,3 +251,21 @@ def test_generate_streaming_content():
     assert chunk_count > 1
 
     print(f"Number of chunks: {chunk_count}")
+
+def test_caching_stable_with_mutation():
+    """
+    Test that the cache key stays stable when kwargs are mutated after the first calculation
+    """
+    litellm.cache = Cache()
+    kwargs = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hello, world!"}],
+        "temperature": 0.7,
+        "litellm_params": {},
+    }
+    cache_key = litellm.cache.get_cache_key(**kwargs)
+
+    # mutate kwargs; the preset key stored in litellm_params keeps the key stable
+    kwargs["temperature"] = 0.8
+    cache_key_2 = litellm.cache.get_cache_key(**kwargs)
+    assert cache_key == cache_key_2
