diff --git a/src/utils/utils.py b/src/utils/utils.py index 2590a0bf..18cb436c 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -182,22 +182,13 @@ def get_llm_model(provider: str, **kwargs): # Callback to update the model name dropdown based on the selected provider -def update_model_dropdown(llm_provider, api_key=None, base_url=None): - """ - Update the model name dropdown with predefined models for the selected provider. - """ - import gradio as gr - # Use API keys from .env if not provided - if not api_key: - api_key = os.getenv(f"{llm_provider.upper()}_API_KEY", "") - if not base_url: - base_url = os.getenv(f"{llm_provider.upper()}_BASE_URL", "") - - # Use predefined models for the selected provider - if llm_provider in model_names: - return gr.Dropdown(choices=model_names[llm_provider], value=model_names[llm_provider][0], interactive=True) - else: - return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True) +def update_model_dropdown(selected_provider, current_model): + """Update model choices based on provider.""" + choices = model_names[selected_provider] + + if current_model not in choices and any(current_model in models for models in model_names.values()): + current_model = choices[0] + return gr.update(choices=choices, value=current_model) class MissingAPIKeyError(Exception): diff --git a/webui.py b/webui.py index bc686055..95f411ae 100644 --- a/webui.py +++ b/webui.py @@ -14,7 +14,7 @@ import gradio as gr import inspect -from functools import wraps +from functools import partial, wraps from browser_use.agent.service import Agent from playwright.async_api import async_playwright @@ -49,37 +49,59 @@ webui_config_manager = utils.ConfigManager() +def sync_component_value_to_manager(component_registered_name, new_value): + """Sync the value of a component to the config manager""" + global webui_config_manager + if webui_config_manager: + component_object = webui_config_manager.components.get(component_registered_name) + if component_object: + current_manager_value = getattr(component_object, "value", None) + if current_manager_value != new_value: + component_object.value = new_value + return None + def scan_and_register_components(blocks): """扫描一个 Blocks 对象并注册其中的所有交互式组件,但不包括按钮""" global webui_config_manager + component_map = {} def traverse_blocks(block, prefix=""): + nonlocal component_map registered = 0 - # 处理 Blocks 自身的组件 if hasattr(block, "children"): for i, child in enumerate(block.children): + name = None + is_eligible_for_config = False if isinstance(child, gr.components.Component): - # 排除按钮 (Button) 组件 - if getattr(child, "interactive", False) and not isinstance(child, gr.Button): + # 排除按钮 (Button/File) 组件 + if getattr(child, "interactive", False) and not isinstance(child, gr.Button) and not isinstance(child, gr.File): + is_eligible_for_config = True name = f"{prefix}component_{i}" if hasattr(child, "label") and child.label: # 使用标签作为名称的一部分 label = child.label name = f"{prefix}{label}" - logger.debug(f"Registering component: {name}") - webui_config_manager.register_component(name, child) - registered += 1 elif hasattr(child, "children"): # 递归处理嵌套的 Blocks new_prefix = f"{prefix}block_{i}_" registered += traverse_blocks(child, new_prefix) + if is_eligible_for_config and name: + webui_config_manager.register_component(name, child) + component_map[name] = child + registered += 1 return registered total = traverse_blocks(blocks) logger.info(f"Total registered components: {total}") + # Register the components with the config manager + for name, component_obj in component_map.items(): + sync_handler = partial(sync_component_value_to_manager, name) + if hasattr(component_obj, 'change'): + component_obj.change(fn=sync_handler, inputs=[component_obj], outputs=None) + def save_current_config(): return webui_config_manager.save_current_config() @@ -1158,8 +1180,8 @@ def list_recordings(save_recording_path): # Attach the callback to the LLM provider dropdown llm_provider.change( - lambda provider, api_key, base_url: update_model_dropdown(provider, api_key, base_url), - inputs=[llm_provider, llm_api_key, llm_base_url], + fn=utils.update_model_dropdown, + inputs=[llm_provider, llm_model_name], outputs=llm_model_name )