diff --git a/README.md b/README.md index 55d60a7f..4275d243 100644 --- a/README.md +++ b/README.md @@ -417,10 +417,10 @@ Custom providers can implement these methods: embed?: string|function, -- Optional: Get extra request headers with optional expiration time - get_headers?(): table, number?, + get_headers?(self: CopilotChat.Provider): table, number?, -- Optional: Get API endpoint URL - get_url?(opts: CopilotChat.Provider.options): string, + get_url?(self: CopilotChat.Provider, opts: CopilotChat.Provider.options): string, -- Optional: Prepare request input prepare_input?(inputs: table, opts: CopilotChat.Provider.options): table, @@ -429,10 +429,10 @@ Custom providers can implement these methods: prepare_output?(output: table, opts: CopilotChat.Provider.options): CopilotChat.Provider.output, -- Optional: Get available models - get_models?(headers: table): table, + get_models?(self: CopilotChat.Provider, headers: table): table, -- Optional: Get available agents - get_agents?(headers: table): table, + get_agents?(self: CopilotChat.Provider, headers: table): table, } ``` diff --git a/doc/CopilotChat.txt b/doc/CopilotChat.txt index 9512331d..0316b5a3 100644 --- a/doc/CopilotChat.txt +++ b/doc/CopilotChat.txt @@ -470,10 +470,10 @@ Custom providers can implement these methods: embed?: string|function, -- Optional: Get extra request headers with optional expiration time - get_headers?(): table, number?, + get_headers?(self: CopilotChat.Provider): table, number?, -- Optional: Get API endpoint URL - get_url?(opts: CopilotChat.Provider.options): string, + get_url?(self: CopilotChat.Provider, opts: CopilotChat.Provider.options): string, -- Optional: Prepare request input prepare_input?(inputs: table, opts: CopilotChat.Provider.options): table, @@ -482,10 +482,10 @@ Custom providers can implement these methods: prepare_output?(output: table, opts: CopilotChat.Provider.options): CopilotChat.Provider.output, -- Optional: Get available models - get_models?(headers: table): table, + get_models?(self: CopilotChat.Provider, headers: table): table, -- Optional: Get available agents - get_agents?(headers: table): table, + get_agents?(self: CopilotChat.Provider, headers: table): table, } < diff --git a/lua/CopilotChat/client.lua b/lua/CopilotChat/client.lua index dddd1b89..6d38c521 100644 --- a/lua/CopilotChat/client.lua +++ b/lua/CopilotChat/client.lua @@ -327,7 +327,7 @@ function Client:authenticate(provider_name) local expires_at = self.provider_cache[provider_name].expires_at if provider.get_headers and (not headers or (expires_at and expires_at <= math.floor(os.time()))) then - headers, expires_at = provider.get_headers() + headers, expires_at = provider:get_headers() self.provider_cache[provider_name].headers = headers self.provider_cache[provider_name].expires_at = expires_at end @@ -354,7 +354,7 @@ function Client:fetch_models() log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers) goto continue end - local ok, provider_models = pcall(provider.get_models, headers) + local ok, provider_models = pcall(provider.get_models, provider, headers) if not ok then log.warn('Failed to fetch models from ' .. provider_name .. ': ' .. provider_models) goto continue @@ -396,7 +396,7 @@ function Client:fetch_agents() log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers) goto continue end - local ok, provider_agents = pcall(provider.get_agents, headers) + local ok, provider_agents = pcall(provider.get_agents, provider, headers) if not ok then log.warn('Failed to fetch agents from ' .. provider_name .. ': ' .. provider_agents) goto continue @@ -671,7 +671,7 @@ function Client:ask(prompt, opts) args.stream = stream_func end - local response, err = utils.curl_post(provider.get_url(options), args) + local response, err = utils.curl_post(provider:get_url(options), args) if not opts.headless then if self.current_job ~= job_id then @@ -815,7 +815,12 @@ function Client:embed(inputs, model) local success = false local attempts = 0 while not success and attempts < 5 do -- Limit total attempts to 5 - local ok, data = pcall(embed, generate_embedding_request(batch, threshold), self:authenticate(provider_name)) + local ok, data = pcall( + embed, + self.providers[models[model].provider], + generate_embedding_request(batch, threshold), + self:authenticate(provider_name) + ) if not ok then log.debug('Failed to get embeddings: ', data) diff --git a/lua/CopilotChat/config/providers.lua b/lua/CopilotChat/config/providers.lua index c34a398b..4dc739bd 100644 --- a/lua/CopilotChat/config/providers.lua +++ b/lua/CopilotChat/config/providers.lua @@ -104,21 +104,22 @@ end ---@class CopilotChat.Provider ---@field disabled nil|boolean ----@field get_headers nil|fun():table,number? ----@field get_agents nil|fun(headers:table):table ----@field get_models nil|fun(headers:table):table ----@field embed nil|string|fun(inputs:table, headers:table):table +---@field get_headers nil|fun(self: CopilotChat.Provider):table,number? +---@field get_agents nil|fun(self: CopilotChat.Provider, headers:table):table +---@field get_models nil|fun(self: CopilotChat.Provider, headers:table):table +---@field embed nil|string|fun(self: CopilotChat.Provider, inputs:table, headers:table):table ---@field prepare_input nil|fun(inputs:table, opts:CopilotChat.Provider.options):table ---@field prepare_output nil|fun(output:table, opts:CopilotChat.Provider.options):CopilotChat.Provider.output ----@field get_url nil|fun(opts:CopilotChat.Provider.options):string +---@field get_url nil|fun(self: CopilotChat.Provider, opts:CopilotChat.Provider.options):string ---@type table local M = {} M.copilot = { embed = 'copilot_embeddings', + api_base = 'https://api.githubcopilot.com', - get_headers = function() + get_headers = function(self) local response, err = utils.curl_get('https://api.github.com/copilot_internal/v2/token', { json_response = true, headers = { @@ -129,6 +130,10 @@ M.copilot = { if err then error(err) end + if response.body.endpoints and response.body.endpoints.api then + ---@diagnostic disable-next-line: inject-field + self.api_base = response.body.endpoints.api + end return { ['Authorization'] = 'Bearer ' .. response.body.token, @@ -139,8 +144,9 @@ M.copilot = { response.body.expires_at end, - get_agents = function(headers) - local response, err = utils.curl_get('https://api.githubcopilot.com/agents', { + get_agents = function(self, headers) + ---@diagnostic disable-next-line: undefined-field + local response, err = utils.curl_get(self.api_base .. '/agents', { json_response = true, headers = headers, }) @@ -158,8 +164,9 @@ M.copilot = { end, response.body.agents) end, - get_models = function(headers) - local response, err = utils.curl_get('https://api.githubcopilot.com/models', { + get_models = function(self, headers) + ---@diagnostic disable-next-line: undefined-field + local response, err = utils.curl_get(self.api_base .. '/models', { json_response = true, headers = headers, }) @@ -197,7 +204,8 @@ M.copilot = { for _, model in ipairs(models) do if not model.policy then - utils.curl_post('https://api.githubcopilot.com/models/' .. model.id .. '/policy', { + ---@diagnostic disable-next-line: undefined-field + utils.curl_post(self.api_base .. '/models/' .. model.id .. '/policy', { headers = headers, json_request = true, body = { state = 'enabled' }, @@ -276,19 +284,21 @@ M.copilot = { } end, - get_url = function(opts) + get_url = function(self, opts) if opts.agent then - return 'https://api.githubcopilot.com/agents/' .. opts.agent.id .. '?chat' + ---@diagnostic disable-next-line: undefined-field + return self.api_base .. '/agents/' .. opts.agent.id .. '?chat' end - return 'https://api.githubcopilot.com/chat/completions' + ---@diagnostic disable-next-line: undefined-field + return self.api_base .. '/chat/completions' end, } M.github_models = { embed = 'copilot_embeddings', - get_headers = function() + get_headers = function(self) return { ['Authorization'] = 'Bearer ' .. get_github_token(), ['x-ms-useragent'] = EDITOR_VERSION, @@ -296,7 +306,7 @@ M.github_models = { } end, - get_models = function(headers) + get_models = function(self, headers) local response, err = utils.curl_post('https://api.catalog.azureml.ms/asset-gallery/v1.0/models', { headers = headers, json_request = true, @@ -344,16 +354,19 @@ M.github_models = { prepare_input = M.copilot.prepare_input, prepare_output = M.copilot.prepare_output, - get_url = function() + get_url = function(self) return 'https://models.inference.ai.azure.com/chat/completions' end, } M.copilot_embeddings = { get_headers = M.copilot.get_headers, + ---@diagnostic disable-next-line: undefined-field + api_base = M.copilot.api_base, - embed = function(inputs, headers) - local response, err = utils.curl_post('https://api.githubcopilot.com/embeddings', { + embed = function(self, inputs, headers) + ---@diagnostic disable-next-line: undefined-field + local response, err = utils.curl_post(self.api_base .. '/embeddings', { headers = headers, json_request = true, json_response = true,