From 9bbffed2e7e344e83ac98bef1b219adc284b61c0 Mon Sep 17 00:00:00 2001 From: Sri Hari Raju Penmatsa Date: Mon, 16 May 2022 14:29:51 +0530 Subject: [PATCH] Interpret "Authorization: Basic :" header as basic auth --- lua/rest-nvim/init.lua | 1 + lua/rest-nvim/request/init.lua | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/lua/rest-nvim/init.lua b/lua/rest-nvim/init.lua index 7e20a162..684149aa 100644 --- a/lua/rest-nvim/init.lua +++ b/lua/rest-nvim/init.lua @@ -25,6 +25,7 @@ rest.run = function(verbose) headers = result.headers, raw = config.get("skip_ssl_verification") and { "-k" } or nil, body = result.body, + auth = result.auth, dry_run = verbose or false, bufnr = result.bufnr, start_line = result.start_line, diff --git a/lua/rest-nvim/request/init.lua b/lua/rest-nvim/request/init.lua index 780b9912..8d0ccdd9 100644 --- a/lua/rest-nvim/request/init.lua +++ b/lua/rest-nvim/request/init.lua @@ -89,12 +89,34 @@ local function is_request_line(line) return false end +-- If the header_line is "Authorization: Basic :", returns ":" else false +local function get_basic_auth_credentials(header_line) + local header = utils.split(header_line, ":", 1) + local header_name = header[1] + local header_value = header[2]:gsub("^%s*", "") + + if header_name ~= "Authorization" then + return false + end + + local resolved_value = utils.replace_vars(header_value) + local i, j = string.find(resolved_value, "Basic ") + local colon_location = string.find(resolved_value, ":") + + if i == 1 and j == 6 and colon_location > 6 then + return string.sub(resolved_value, 7) + else + return false + end +end + -- get_headers retrieves all the found headers and returns a lua table with them -- @param bufnr Buffer number, a.k.a id -- @param start_line Line where the request starts -- @param end_line Line where the request ends local function get_headers(bufnr, start_line, end_line) local headers = {} + local auth = nil local body_start = end_line -- Iterate over all buffer lines starting after the request line @@ -114,16 +136,21 @@ local function get_headers(bufnr, start_line, end_line) end local header = utils.split(line_content, ":") + local basic_auth = get_basic_auth_credentials(line_content) local header_name = header[1]:lower() table.remove(header, 1) local header_value = table.concat(header, ":") if not utils.contains_comments(header_name) then - headers[header_name] = utils.replace_vars(header_value) + if basic_auth then + auth = basic_auth + else + headers[header_name] = utils.replace_vars(header_value) + end end ::continue:: end - return headers, body_start + return headers, auth, body_start end -- start_request will find the request line (e.g. POST http://localhost:8081/foo) @@ -191,7 +218,7 @@ M.get_current_request = function() local parsed_url = parse_url(vim.fn.getline(start_line)) - local headers, body_start = get_headers(bufnr, start_line, end_line) + local headers, auth, body_start = get_headers(bufnr, start_line, end_line) local body = get_body(bufnr, body_start, end_line) @@ -206,6 +233,7 @@ M.get_current_request = function() url = parsed_url.url, headers = headers, body = body, + auth = auth, bufnr = bufnr, start_line = start_line, end_line = end_line,