diff --git a/Cargo.lock b/Cargo.lock
index d1e342f9be..7d9dc1f278 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4033,6 +4033,16 @@ dependencies = [
  "windows-sys 0.59.0",
 ]
 
+[[package]]
+name = "htmd"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ad1642def6e8e4dc182941f35454f7d2af917787f91f3f5133300030b41006d0"
+dependencies = [
+ "html5ever 0.27.0",
+ "markup5ever_rcdom",
+]
+
 [[package]]
 name = "html5ever"
 version = "0.26.0"
@@ -4041,12 +4051,26 @@ checksum = "bea68cab48b8459f17cf1c944c67ddc572d272d9f2b274140f223ecb1da4a3b7"
 dependencies = [
  "log",
  "mac",
- "markup5ever",
+ "markup5ever 0.11.0",
  "proc-macro2",
  "quote",
  "syn 1.0.109",
 ]
 
+[[package]]
+name = "html5ever"
+version = "0.27.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c13771afe0e6e846f1e67d038d4cb29998a6779f93c809212e4e9c32efd244d4"
+dependencies = [
+ "log",
+ "mac",
+ "markup5ever 0.12.1",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
+]
+
 [[package]]
 name = "http"
 version = "0.2.12"
@@ -4739,7 +4763,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f29e4755b7b995046f510a7520c42b2fed58b77bd94d5a87a8eb43d2fd126da8"
 dependencies = [
  "cssparser",
- "html5ever",
+ "html5ever 0.26.0",
  "indexmap 1.9.3",
  "matches",
  "selectors",
@@ -5024,6 +5048,32 @@ dependencies = [
  "tendril",
 ]
 
+[[package]]
+name = "markup5ever"
+version = "0.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "16ce3abbeba692c8b8441d036ef91aea6df8da2c6b6e21c7e14d3c18e526be45"
+dependencies = [
+ "log",
+ "phf 0.11.3",
+ "phf_codegen 0.11.3",
+ "string_cache",
+ "string_cache_codegen",
+ "tendril",
+]
+
+[[package]]
+name = "markup5ever_rcdom"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "edaa21ab3701bfee5099ade5f7e1f84553fd19228cf332f13cd6e964bf59be18"
+dependencies = [
+ "html5ever 0.27.0",
+ "markup5ever 0.12.1",
+ "tendril",
+ "xml5ever",
+]
+
 [[package]]
 name = "matchers"
 version = "0.1.0"
@@ -6222,6 +6272,15 @@ dependencies = [
  "phf_shared 0.10.0",
 ]
 
+[[package]]
+name = "phf"
+version = "0.11.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078"
+dependencies = [
+ "phf_shared 0.11.3",
+]
+
 [[package]]
 name = "phf_codegen"
 version = "0.8.0"
@@ -6242,6 +6301,16 @@ dependencies = [
  "phf_shared 0.10.0",
 ]
 
+[[package]]
+name = "phf_codegen"
+version = "0.11.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a"
+dependencies = [
+ "phf_generator 0.11.3",
+ "phf_shared 0.11.3",
+]
+
 [[package]]
 name = "phf_generator"
 version = "0.8.0"
@@ -6262,6 +6331,16 @@ dependencies = [
  "rand 0.8.5",
 ]
 
+[[package]]
+name = "phf_generator"
+version = "0.11.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d"
+dependencies = [
+ "phf_shared 0.11.3",
+ "rand 0.8.5",
+]
+
 [[package]]
 name = "phf_macros"
 version = "0.8.0"
@@ -6282,7 +6361,7 @@ version = "0.8.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c00cf8b9eafe68dde5e9eaa2cef8ee84a9336a47d566ec55ca16589633b65af7"
 dependencies = [
- "siphasher",
+ "siphasher 0.3.11",
 ]
 
 [[package]]
@@ -6291,7 +6370,16 @@ version = "0.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096"
 dependencies = [
- "siphasher",
+ "siphasher 0.3.11",
+]
+
+[[package]]
+name = "phf_shared"
+version = "0.11.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
+dependencies = [
+ "siphasher 1.0.1",
 ]
 
 [[package]]
@@ -6827,11 +6915,13 @@ dependencies = [
  "fig_auth",
  "fig_diagnostic",
  "fig_os_shim",
+ "fig_request",
  "fig_settings",
  "fig_telemetry",
  "fig_util",
  "futures",
  "glob",
+ "htmd",
  "mcp_client",
  "rand 0.9.0",
  "regex",
@@ -8127,6 +8217,12 @@ version = "0.3.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d"
 
+[[package]]
+name = "siphasher"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
+
 [[package]]
 name = "skim"
 version = "0.16.1"
@@ -10398,7 +10494,7 @@ dependencies = [
  "dunce",
  "gdkx11",
  "gtk",
- "html5ever",
+ "html5ever 0.26.0",
  "http 1.2.0",
  "javascriptcore-rs",
  "jni",
@@ -10492,6 +10588,17 @@ dependencies = [
  "windows-sys 0.59.0",
 ]
 
+[[package]]
+name = "xml5ever"
+version = "0.18.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9bbb26405d8e919bc1547a5aa9abc95cbfa438f04844f5fdd9dc7596b748bf69"
+dependencies = [
+ "log",
+ "mac",
+ "markup5ever 0.12.1",
+]
+
 [[package]]
 name = "xmlparser"
 version = "0.13.6"
diff --git a/crates/q_chat/Cargo.toml b/crates/q_chat/Cargo.toml
index ee38cde55b..f5949a3556 100644
--- a/crates/q_chat/Cargo.toml
+++ b/crates/q_chat/Cargo.toml
@@ -23,8 +23,10 @@ fig_os_shim.workspace = true
 fig_settings.workspace = true
 fig_telemetry.workspace = true
 fig_util.workspace = true
+fig_request.workspace = true
 futures.workspace = true
 glob.workspace = true
+htmd = "0.1"
 mcp_client.workspace = true
 rand.workspace = true
 regex.workspace = true
diff --git a/crates/q_chat/src/tool_manager.rs b/crates/q_chat/src/tool_manager.rs
index 4245cd1009..5e45b13373 100644
--- a/crates/q_chat/src/tool_manager.rs
+++ b/crates/q_chat/src/tool_manager.rs
@@ -57,6 +57,7 @@ use super::tools::{
 };
 use crate::tools::ToolSpec;
 use crate::tools::custom_tool::CustomTool;
+use crate::tools::web_search::WebSearch;
 
 const NAMESPACE_DELIMITER: &str = "___";
 // This applies for both mcp server and tool name since in the end the tool name as seen by the
@@ -671,6 +672,7 @@ impl ToolManager {
             "execute_bash" => Tool::ExecuteBash(serde_json::from_value::<ExecuteBash>(value.args).map_err(map_err)?),
             "use_aws" => Tool::UseAws(serde_json::from_value::<UseAws>(value.args).map_err(map_err)?),
             "report_issue" => Tool::GhIssue(serde_json::from_value::<GhIssue>(value.args).map_err(map_err)?),
+            "web_search" => Tool::WebSearch(serde_json::from_value::<WebSearch>(value.args).map_err(map_err)?),
             // Note that this name is namespaced with server_name{DELIMITER}tool_name
             name => {
                 let name = self.tn_map.get(name).map_or(name, String::as_str);
diff --git a/crates/q_chat/src/tools/mod.rs b/crates/q_chat/src/tools/mod.rs
index 279586736b..b6ff3ec848 100644
--- a/crates/q_chat/src/tools/mod.rs
+++ b/crates/q_chat/src/tools/mod.rs
@@ -4,6 +4,7 @@ pub mod fs_read;
 pub mod fs_write;
 pub mod gh_issue;
 pub mod use_aws;
+pub mod web_search;
 
 use std::collections::HashMap;
 use std::io::Write;
@@ -29,6 +30,7 @@ use serde::{
     Serialize,
 };
 use use_aws::UseAws;
+use web_search::WebSearch;
 
 use super::consts::MAX_TOOL_RESPONSE_SIZE;
 
@@ -41,6 +43,7 @@ pub enum Tool {
     UseAws(UseAws),
     Custom(CustomTool),
     GhIssue(GhIssue),
+    WebSearch(WebSearch),
 }
 
 impl Tool {
@@ -53,6 +56,7 @@ impl Tool {
             Tool::UseAws(_) => "use_aws",
             Tool::Custom(custom_tool) => &custom_tool.name,
             Tool::GhIssue(_) => "gh_issue",
+            Tool::WebSearch(_) => "web_search",
         }
         .to_owned()
     }
@@ -66,6 +70,7 @@ impl Tool {
             Tool::UseAws(use_aws) => use_aws.requires_acceptance(),
             Tool::Custom(_) => true,
             Tool::GhIssue(_) => false,
+            Tool::WebSearch(_) => false,
         }
     }
 
@@ -78,6 +83,7 @@ impl Tool {
             Tool::UseAws(use_aws) => use_aws.invoke(context, updates).await,
             Tool::Custom(custom_tool) => custom_tool.invoke(context, updates).await,
             Tool::GhIssue(gh_issue) => gh_issue.invoke(updates).await,
+            Tool::WebSearch(web_search) => web_search.invoke(updates).await,
         }
     }
 
@@ -90,6 +96,7 @@ impl Tool {
             Tool::UseAws(use_aws) => use_aws.queue_description(updates),
             Tool::Custom(custom_tool) => custom_tool.queue_description(updates),
             Tool::GhIssue(gh_issue) => gh_issue.queue_description(updates),
+            Tool::WebSearch(web_search) => web_search.queue_description(updates),
         }
     }
 
@@ -102,6 +109,7 @@ impl Tool {
             Tool::UseAws(use_aws) => use_aws.validate(ctx).await,
             Tool::Custom(custom_tool) => custom_tool.validate(ctx).await,
             Tool::GhIssue(gh_issue) => gh_issue.validate(ctx).await,
+            Tool::WebSearch(web_search) => web_search.validate(ctx).await,
         }
     }
 }
@@ -175,6 +183,7 @@ impl ToolPermissions {
             "execute_bash" => "trust read-only commands".dark_grey(),
             "use_aws" => "trust read-only commands".dark_grey(),
             "report_issue" => "trusted".dark_green().bold(),
+            "web_search" => "trusted".dark_green().bold(),
             _ => "not trusted".dark_grey(),
         };
 
diff --git a/crates/q_chat/src/tools/tool_index.json b/crates/q_chat/src/tools/tool_index.json
index 397d856cfa..1c1c630824 100644
--- a/crates/q_chat/src/tools/tool_index.json
+++ b/crates/q_chat/src/tools/tool_index.json
@@ -147,6 +147,33 @@
       ]
     }
   },
+  "web_search": {
+    "name": "web_search",
+    "description": "Search the web for the specified query, or retrieve the content of a specific web page as markdown. Currently only retrieving is supported.",
+    "input_schema": {
+      "type": "object",
+      "properties": {
+        "query": {
+          "type": "string",
+          "description": "The search query to use. This is optional when mode is set to Retrieve, since the target_url is used instead."
+        },
+        "mode": {
+          "type": "string",
+          "enum": [
+            "Retrieve"
+          ],
+          "description": "Retrieve mode returns the markdown representation of the page at target_url. Search mode (not currently supported) will return the first few results from a search engine."
+        },
+        "target_url": {
+          "type": "string",
+          "description": "The web page to retrieve. This is only used in Retrieve mode."
+        }
+      },
+      "required": [
+        "mode"
+      ]
+    }
+  },
   "gh_issue": {
     "name": "report_issue",
     "description": "Opens the browser to a pre-filled gh (GitHub) issue template to report chat issues, bugs, or feature requests. Pre-filled information includes the conversation transcript, chat context, and chat request IDs from the service.",
diff --git a/crates/q_chat/src/tools/web_search.rs b/crates/q_chat/src/tools/web_search.rs
new file mode 100644
index 0000000000..9253129b06
--- /dev/null
+++ b/crates/q_chat/src/tools/web_search.rs
@@ -0,0 +1,192 @@
+use std::io::Write;
+
+use crossterm::queue;
+use crossterm::style::{
+    self,
+    Stylize,
+};
+use eyre::Result;
+use fig_os_shim::Context;
+use fig_request::reqwest;
+use htmd::HtmlToMarkdown;
+use serde::Deserialize;
+
+use super::{
+    InvokeOutput,
+    OutputKind,
+};
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct WebSearch {
+    pub query: Option<String>,
+    pub mode: WebSearchMode,
+    pub target_url: Option<String>,
+}
+
+#[derive(Debug, Clone, Deserialize, PartialEq)]
+pub enum WebSearchMode {
+    Search,
+    Retrieve,
+}
+
+impl WebSearch {
+    pub async fn invoke(&self, _updates: impl Write) -> Result<InvokeOutput> {
+        let query = self.query.as_deref().unwrap_or("");
+        let target_url = self.target_url.as_deref().unwrap_or("");
+
+        // Perform web search or retrieve based on the mode
+        match self.mode {
+            // TODO - need to align on what search engine to use
+            WebSearchMode::Search => {
+                if query.is_empty() {
+                    return Err(eyre::eyre!("Query is required for web search"));
+                }
+                // Perform web search using the query
+                // ...
+            },
+            WebSearchMode::Retrieve => {
+                if target_url.is_empty() {
+                    return Err(eyre::eyre!("Target URL is required for retrieving"));
+                }
+
+                // Parse the target URL to get the base domain for robots.txt
+                let parsed_url =
+                    url::Url::parse(target_url).map_err(|e| eyre::eyre!("Failed to parse target URL: {}", e))?;
+
+                // Construct robots.txt URL
+                let robots_url = format!(
+                    "{}://{}/robots.txt",
+                    parsed_url.scheme(),
+                    parsed_url
+                        .host_str()
+                        .ok_or_else(|| eyre::eyre!("Invalid host in URL"))?
+                );
+
+                let user_agent = "AmazonQCLI/1.0";
+                let client = reqwest::Client::new();
+
+                // Check robots.txt first
+                let robots_resp = client.get(&robots_url).send().await;
+                // If robots.txt exists, check if we're allowed to access
+                if let Ok(robots_resp) = robots_resp {
+                    if robots_resp.status().is_success() {
+                        let robots_content = robots_resp
+                            .text()
+                            .await
+                            .map_err(|e| eyre::eyre!("Failed to read robots.txt: {}", e))?;
+
+                        // Simple robots.txt parsing
+                        let path = parsed_url.path();
+                        if !Self::is_allowed_by_robots_txt(&robots_content, user_agent, path) {
+                            return Err(eyre::eyre!("Access to this URL is disallowed by robots.txt"));
+                        }
+                    }
+                }
+
+                // Send a GET request to the target URL with a custom User-Agent header
+                let response = client
+                    .get(target_url)
+                    .header(reqwest::header::USER_AGENT, user_agent)
+                    .send()
+                    .await
+                    .map_err(|e| eyre::eyre!("Failed to connect to target URL: {}", e))?;
+
+                // Check if the request was successful
+                if !response.status().is_success() {
+                    return Err(eyre::eyre!("Request failed with status: {}", response.status()));
+                }
+                // Get the response body as text
+                let html_string = response
+                    .text()
+                    .await
+                    .map_err(|e| eyre::eyre!("Failed to read response body: {}", e))?;
+
+                // Convert HTML to Markdown
+                let converter = HtmlToMarkdown::builder().skip_tags(vec!["script", "style"]).build();
+
+                return Ok(InvokeOutput {
+                    output: OutputKind::Json(serde_json::json!({
+                        "mkd_content": converter.convert(&html_string).map_err(|e| eyre::eyre!("Failed to convert HTML: {}", e))?,
+                        "target_url": target_url,
+                    })),
+                });
+            },
+        }
+
+        Ok(Default::default())
+    }
+
+    pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> {
+        queue!(
+            updates,
+            style::Print(format!(
+                "{} {}...",
+                if self.mode == WebSearchMode::Search {
+                    "Searching"
+                } else {
+                    "Retrieving"
+                },
+                if self.mode == WebSearchMode::Search {
+                    self.query.as_ref().unwrap_or(&"".to_string()).clone().dark_green()
+                } else {
+                    self.target_url.as_ref().unwrap_or(&"".to_string()).clone().dark_green()
+                }
+            )),
+        )?;
+        Ok(())
+    }
+
+    pub async fn validate(&mut self, _ctx: &Context) -> Result<()> {
+        if self.mode == WebSearchMode::Search && self.query.is_none() {
+            return Err(eyre::eyre!("Query is required for web search"));
+        }
+        if self.mode == WebSearchMode::Retrieve && self.target_url.is_none() {
+            return Err(eyre::eyre!("Target URL is required for retrieving"));
+        }
+
+        Ok(())
+    }
+
+    // Simple function to check if a path is allowed by robots.txt
+    fn is_allowed_by_robots_txt(robots_content: &str, user_agent: &str, path: &str) -> bool {
+        let mut current_agent;
+        let mut disallowed_paths = Vec::new();
+        let mut is_relevant_agent = false;
+
+        // Very basic robots.txt parser
+        for line in robots_content.lines() {
+            let line = line.trim();
+
+            // Skip comments and empty lines
+            if line.is_empty() || line.starts_with('#') {
+                continue;
+            }
+
+            // Parse User-agent line
+            if let Some(agent) = line.strip_prefix("User-agent:") {
+                current_agent = agent.trim();
+                is_relevant_agent = current_agent == "*" || current_agent == user_agent;
+                continue;
+            }
+
+            // Parse Disallow line if it's for our user agent
+            if is_relevant_agent {
+                if let Some(disallow_path) = line.strip_prefix("Disallow:") {
+                    let disallow_path = disallow_path.trim();
+                    if !disallow_path.is_empty() {
+                        disallowed_paths.push(disallow_path);
+                    }
+                }
+            }
+        }
+
+        // Check if the path is disallowed
+        for disallow in &disallowed_paths {
+            if path.starts_with(disallow) || *disallow == "/" {
+                return false;
+            }
+        }
+
+        true
+    }
+}
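Note: a minimal sketch of how the robots.txt helper added above could be exercised. This is illustrative only and not part of the patch; it assumes is_allowed_by_robots_txt stays a private associated function on WebSearch, so a #[cfg(test)] module at the bottom of web_search.rs could cover the two rule shapes the parser understands (path-prefix disallows and a blanket "Disallow: /").

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn robots_txt_disallow_rules_are_applied() {
        // Hypothetical robots.txt: the wildcard agent may not fetch anything under /private/.
        let robots = "User-agent: *\nDisallow: /private/\n";
        assert!(!WebSearch::is_allowed_by_robots_txt(robots, "AmazonQCLI/1.0", "/private/page.html"));
        assert!(WebSearch::is_allowed_by_robots_txt(robots, "AmazonQCLI/1.0", "/public/page.html"));

        // A blanket "Disallow: /" for the matching agent blocks every path.
        let blanket = "User-agent: AmazonQCLI/1.0\nDisallow: /\n";
        assert!(!WebSearch::is_allowed_by_robots_txt(blanket, "AmazonQCLI/1.0", "/anything"));
    }
}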