# llm.py

import os

os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
import sys

import litellm

from . import __version__

litellm.suppress_debug_info = True
litellm.REPEATED_STREAMING_CHUNK_LIMIT = 99999999
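
# Note: LITELLM_LOCAL_MODEL_COST_MAP is set above *before* `import litellm`,
# so litellm reads its bundled model-cost map instead of fetching one over the
# network at import time.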

import json
import logging
import subprocess
import time
import uuid

import requests
import tokentrim as tt

from .run_text_llm import run_text_llm

# from .run_function_calling_llm import run_function_calling_llm
from .run_tool_calling_llm import run_tool_calling_llm
from .utils.convert_to_openai_messages import convert_to_openai_messages

# Create or get the logger
logger = logging.getLogger("LiteLLM")


class SuppressDebugFilter(logging.Filter):
    def filter(self, record):
        # Suppress only the specific message containing the keywords
        if "cost map" in record.getMessage():
            return False  # Suppress this log message
        return True  # Allow all other messages


class Llm:
    """
    A stateless LMC-style LLM with some helpful properties.
    """

    def __init__(self, interpreter):
        # Add the filter to the logger
        logger.addFilter(SuppressDebugFilter())

        # Store a reference to the parent interpreter
        self.interpreter = interpreter

        # OpenAI-compatible chat completions "endpoint"
        self.completions = fixed_litellm_completions

        # Settings
        self.model = "gpt-4o"
        self.temperature = 0
        self.supports_vision = None  # Will try to auto-detect
        self.vision_renderer = (
            self.interpreter.computer.vision.query
        )  # Will only be used if supports_vision is False
        self.supports_functions = None  # Will try to auto-detect
        # If supports_functions is False, this will be added to the system message:
        self.execution_instructions = "To execute code on the user's machine, write a markdown code block. Specify the language after the ```. You will receive the output. Use any programming language."

        # Optional settings
        self.context_window = None
        self.max_tokens = None
        self.api_base = None
        self.api_key = None
        self.api_version = None
        self._is_loaded = False

        # Budget manager powered by LiteLLM
        self.max_budget = None
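
    # A minimal configuration sketch (illustrative values; these attributes are
    # typically reached through an interpreter instance, e.g. `interpreter.llm`):
    #
    #   interpreter.llm.model = "gpt-4o"
    #   interpreter.llm.context_window = 128000
    #   interpreter.llm.max_tokens = 4096
    #   interpreter.llm.api_key = "sk-..."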

    def run(self, messages):
        """
        Formats the call into the llm.completions object: starts with LMC
        messages in interpreter.messages, converts them to OpenAI-compatible
        messages for the LLM, and respects whether the model supports vision or
        function calling, as well as its context window and max tokens. Then
        processes the output, whether it came from a function-calling or
        non-function-calling model, back into LMC format.
        """
        if not self._is_loaded:
            self.load()

        if (
            self.max_tokens is not None
            and self.context_window is not None
            and self.max_tokens > self.context_window
        ):
            print(
                "Warning: max_tokens is larger than context_window. Setting max_tokens to be 0.2 times the context_window."
            )
            self.max_tokens = int(0.2 * self.context_window)

        # Assertions
        assert (
            messages[0]["role"] == "system"
        ), "First message must have the role 'system'"
        for msg in messages[1:]:
            assert (
                msg["role"] != "system"
            ), "No message after the first can have the role 'system'"

        model = self.model
        if model in [
            "claude-3.5",
            "claude-3-5",
            "claude-3.5-sonnet",
            "claude-3-5-sonnet",
        ]:
            model = "claude-3-5-sonnet-20240620"
            self.model = "claude-3-5-sonnet-20240620"

        # Set up our model endpoint
        if model == "i":
            model = "openai/i"

            if not hasattr(self.interpreter, "conversation_id"):  # Only do this once
                self.context_window = 7000
                self.api_key = "x"
                self.max_tokens = 1000
                self.api_base = "https://api.openinterpreter.com/v0"
                self.interpreter.conversation_id = str(uuid.uuid4())

        # Detect function support
        if self.supports_functions is None:
            try:
                if litellm.supports_function_calling(model):
                    self.supports_functions = True
                else:
                    self.supports_functions = False
            except Exception:
                self.supports_functions = False

        # Detect vision support
        if self.supports_vision is None:
            try:
                if litellm.supports_vision(model):
                    self.supports_vision = True
                else:
                    self.supports_vision = False
            except Exception:
                self.supports_vision = False
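
        # Note: these litellm capability checks rely on litellm's model info /
        # cost map (loaded locally via the env var at the top of this file), and
        # they can raise for models litellm doesn't recognize, hence the
        # try/except fallbacks above.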

        # Trim image messages if they're there
        image_messages = [msg for msg in messages if msg["type"] == "image"]
        if self.supports_vision:
            if self.interpreter.os:
                # Keep only the last two images if the interpreter is running in OS mode
                if len(image_messages) > 1:
                    for img_msg in image_messages[:-2]:
                        messages.remove(img_msg)
                        if self.interpreter.verbose:
                            print("Removing image message!")
            else:
                # Delete all the middle ones (leave only the first and last two images)
                if len(image_messages) > 3:
                    for img_msg in image_messages[1:-2]:
                        messages.remove(img_msg)
                        if self.interpreter.verbose:
                            print("Removing image message!")
                # Idea: we could set detail: low for the middle messages, instead of deleting them
        elif self.supports_vision is False and self.vision_renderer:
            for img_msg in image_messages:
                if img_msg["format"] != "description":
                    self.interpreter.display_message("\n *Viewing image...*\n")

                    if img_msg["format"] == "path":
                        precursor = f"The image I'm referring to ({img_msg['content']}) contains the following: "
                        if self.interpreter.computer.import_computer_api:
                            postcursor = f"\nIf you want to ask questions about the image, run `computer.vision.query(path='{img_msg['content']}', query='(ask any question here)')` and a vision AI will answer it."
                        else:
                            postcursor = ""
                    else:
                        precursor = "Imagine I have just shown you an image with this description: "
                        postcursor = ""

                    try:
                        image_description = self.vision_renderer(lmc=img_msg)
                        ocr = self.interpreter.computer.vision.ocr(lmc=img_msg)

                        # It would be nice to format this as a message to the user and display it like: "I see: image_description"
                        img_msg["content"] = (
                            precursor
                            + image_description
                            + "\n---\nI've OCR'd the image; this is the result (it may or may not be relevant; if it isn't, ignore it): '''\n"
                            + ocr
                            + "\n'''"
                            + postcursor
                        )
                        img_msg["format"] = "description"
                    except ImportError:
                        print(
                            "\nTo use local vision, run `pip install 'open-interpreter[local]'`.\n"
                        )
                        img_msg["format"] = "description"
                        img_msg["content"] = ""

        # Convert to OpenAI messages format
        messages = convert_to_openai_messages(
            messages,
            function_calling=self.supports_functions,
            vision=self.supports_vision,
            shrink_images=self.interpreter.shrink_images,
            interpreter=self.interpreter,
        )

        system_message = messages[0]["content"]
        messages = messages[1:]
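
        # After conversion, `messages` is standard OpenAI chat format, e.g.
        # (illustrative): {"role": "user", "content": "Open a browser."}
        # The system message is split off so the trimmer below can preserve it
        # while trimming the rest of the history.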

        # Trim messages
        try:
            if self.context_window and self.max_tokens:
                trim_to_be_this_many_tokens = (
                    self.context_window - self.max_tokens - 25
                )  # arbitrary buffer
                messages = tt.trim(
                    messages,
                    system_message=system_message,
                    max_tokens=trim_to_be_this_many_tokens,
                )
            elif self.context_window and not self.max_tokens:
                # Just trim to the context window if max_tokens isn't set
                messages = tt.trim(
                    messages,
                    system_message=system_message,
                    max_tokens=self.context_window,
                )
            else:
                try:
                    messages = tt.trim(
                        messages, system_message=system_message, model=model
                    )
                except Exception:
                    if len(messages) == 1:
                        if self.interpreter.in_terminal_interface:
                            self.interpreter.display_message(
                                """
**We were unable to determine the context window of this model.** Defaulting to 8000.

If your model can handle more, run `interpreter --context_window {token limit} --max_tokens {max tokens per response}`.

Continuing...
                                """
                            )
                        else:
                            self.interpreter.display_message(
                                """
**We were unable to determine the context window of this model.** Defaulting to 8000.

If your model can handle more, run `self.context_window = {token limit}`.

Also please set `self.max_tokens = {max tokens per response}`.

Continuing...
                                """
                            )
                    messages = tt.trim(
                        messages, system_message=system_message, max_tokens=8000
                    )
        except Exception:
            # If we're trimming messages, this won't work.
            # If we're trimming from a model we don't know, this won't work.
            # Better not to fail until `messages` is too big, just for frustration's sake, I suppose.

            # Reunite the system message with the rest of the messages
            messages = [{"role": "system", "content": system_message}] + messages
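
        # Worked example of the trim budget above (illustrative numbers): with
        # context_window=8000 and max_tokens=1000, the history is trimmed to
        # 8000 - 1000 - 25 = 6975 tokens, leaving room for the response plus a
        # small buffer.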

        # If there should be a system message, there should be a system message!
        # Empty system messages appear to be deleted :(
        if system_message == "":
            if messages[0]["role"] != "system":
                messages = [{"role": "system", "content": system_message}] + messages

        ## Start forming the request

        params = {
            "model": model,
            "messages": messages,
            "stream": True,
        }

        # Optional inputs
        if self.api_key:
            params["api_key"] = self.api_key
        if self.api_base:
            params["api_base"] = self.api_base
        if self.api_version:
            params["api_version"] = self.api_version
        if self.max_tokens:
            params["max_tokens"] = self.max_tokens
        if self.temperature:
            # Note: a temperature of 0 is falsy, so the default of 0 is never
            # sent explicitly
            params["temperature"] = self.temperature
        if hasattr(self.interpreter, "conversation_id"):
            params["conversation_id"] = self.interpreter.conversation_id

        # Set some params directly on LiteLLM
        if self.max_budget:
            litellm.max_budget = self.max_budget
        if self.interpreter.verbose:
            litellm.set_verbose = True

        if (
            self.interpreter.debug == True and False  # DISABLED
        ):  # debug will equal "server" if we're debugging the server specifically
            print("\n\n\nOPENAI COMPATIBLE MESSAGES:\n\n\n")
            for message in messages:
                if len(str(message)) > 5000:
                    print(str(message)[:200] + "...")
                else:
                    print(message)
                print("\n")
            print("\n\n\n")

        if self.supports_functions:
            # yield from run_function_calling_llm(self, params)
            yield from run_tool_calling_llm(self, params)
        else:
            yield from run_text_llm(self, params)

    # If you change the model, set _is_loaded to False
    @property
    def model(self):
        return self._model

    @model.setter
    def model(self, value):
        self._model = value
        self._is_loaded = False
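
    # Usage note: reassigning the model marks the LLM as unloaded, so the next
    # run() call re-runs load(), e.g.:
    #   llm.model = "ollama/llama3"  # illustrative model name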

    def load(self):
        if self._is_loaded:
            return

        if self.model.startswith("ollama/") and ":" not in self.model:
            self.model = self.model + ":latest"

        self._is_loaded = True

        if self.model.startswith("ollama/"):
            model_name = self.model.replace("ollama/", "")
            api_base = getattr(self, "api_base", None) or os.getenv(
                "OLLAMA_HOST", "http://localhost:11434"
            )
            names = []
            try:
                # List out all downloaded ollama models. Will fail if ollama isn't installed
                response = requests.get(f"{api_base}/api/tags")
                if response.ok:
                    data = response.json()
                    names = [
                        model["name"]
                        for model in data["models"]
                        if "name" in model and model["name"]
                    ]
            except Exception as e:
                print(str(e))
                self.interpreter.display_message(
                    f"> Ollama not found\n\nPlease download Ollama from [ollama.com](https://ollama.com/) to use `{model_name}`.\n"
                )
                sys.exit(1)

            # Download the model if it isn't already installed
            if model_name not in names:
                self.interpreter.display_message(f"\nDownloading {model_name}...\n")
                requests.post(f"{api_base}/api/pull", json={"name": model_name})

            # Get the context window if it isn't set
            if self.context_window is None:
                response = requests.post(
                    f"{api_base}/api/show", json={"name": model_name}
                )
                model_info = response.json().get("model_info", {})
                context_length = None
                for key in model_info:
                    # Ollama nests the context length under an architecture-prefixed
                    # key (e.g. "llama.context_length"), so scan for it
                    if "context_length" in key:
                        context_length = model_info[key]
                        break
                if context_length is not None:
                    self.context_window = context_length

            if self.max_tokens is None and self.context_window is not None:
                self.max_tokens = int(self.context_window * 0.2)

            # Send a ping, which will actually load the model
            model_name = model_name.replace(":latest", "")
            print(f"Loading {model_name}...\n")
            old_max_tokens = self.max_tokens
            self.max_tokens = 1
            self.interpreter.computer.ai.chat("ping")
            self.max_tokens = old_max_tokens
            self.interpreter.display_message("*Model loaded.*\n")

        # Validate LLM should be moved here!!
        if self.context_window is None:
            try:
                model_info = litellm.get_model_info(model=self.model)
                self.context_window = model_info["max_input_tokens"]
                if self.max_tokens is None:
                    self.max_tokens = min(
                        int(self.context_window * 0.2), model_info["max_output_tokens"]
                    )
            except Exception:
                pass


def fixed_litellm_completions(**params):
    """
    Just uses a dummy API key, since we sometimes use litellm without one.
    Hopefully they will fix this!
    """
    if "local" in params.get("model", ""):
        # Kinda hacky, but this helps sometimes
        params["stop"] = ["<|assistant|>", "<|end|>", "<|eot_id|>"]

    if params.get("model") == "i" and "conversation_id" in params:
        litellm.drop_params = (
            False  # If we don't do this, litellm will drop this param!
        )
    else:
        litellm.drop_params = True
        params["model"] = params["model"].replace(":latest", "")

    # Run the completion
    attempts = 4
    first_error = None

    params["num_retries"] = 0
    params["headers"] = {"x-coding-assistant": f"open-interpreter/{__version__}"}

    for attempt in range(attempts):
        try:
            yield from litellm.completion(**params)
            return  # If the completion is successful, exit the function
        except KeyboardInterrupt:
            print("Exiting...")
            sys.exit(0)
        except Exception as e:
            if attempt == 0:
                # Store the first error
                first_error = e
            if (
                isinstance(e, litellm.exceptions.AuthenticationError)
                and "api_key" not in params
            ):
                print(
                    "LiteLLM requires an API key. Trying again with a dummy API key. In the future, if this fixes it, please set a dummy API key to prevent this message (e.g. `interpreter --api_key x` or `self.api_key = 'x'`)."
                )
                # So, let's try one more time with a dummy API key:
                params["api_key"] = "x"
            if attempt == 1:
                # Try turning up the temperature slightly
                params["temperature"] = params.get("temperature", 0.0) + 0.1

    if first_error is not None:
        raise first_error  # If all attempts fail, raise the first error
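
# Minimal usage sketch (illustrative; assumes an interpreter instance and valid
# credentials are already configured):
#
#   llm = Llm(interpreter)
#   for chunk in llm.run(messages):
#       ...  # LMC-style deltas from run_tool_calling_llm / run_text_llm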