Skip to content

feat: curl commands from .sh filetype support #415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions lua/rest-nvim/parser/curl.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
---@mod rest-nvim.parser.curl rest.nvim curl parsing module
---
---@brief [[
---
--- rest.nvim curl command parsing module
--- rest.nvim uses `tree-sitter-bash` as a core parser to parse raw curl commands
---
---@brief ]]

local curl_parser = {}

local utils = require("rest-nvim.utils")
local logger = require("rest-nvim.logger")

---@param node TSNode Tree-sitter request node
---@param source Source
function curl_parser.parse_command(node, source)
assert(node:type() == "command")
assert(utils.ts_field_text(node, "name", source) == "curl")
local arg_nodes = node:field("argument")
if #arg_nodes < 1 then
logger.error("can't parse curl command with 0 arguments")
return
end
local args = {}
for _, arg_node in ipairs(arg_nodes) do
local arg_type = arg_node:type()
if arg_type == "word" then
table.insert(args, vim.treesitter.get_node_text(arg_node, source))
elseif arg_type == "raw_string" then
-- FIXME: expand escaped sequences like `\n`
table.insert(args, vim.treesitter.get_node_text(arg_node, source):sub(2, -2))
else
logger.error(("can't parse argument type: '%s'"):format(arg_type))
return
end
end
return args
end

-- -X, --request
-- The request method to use.
-- -H, --header
-- The request header to include in the request.
-- -u, --user | --basic | --digest
-- The user's credentials to be provided with the request, and the authorization method to use.
-- -d, --data, --data-ascii | --data-binary | --data-raw | --data-urlencode
-- The data to be sent in a POST request.
-- -F, --form
-- The multipart/form-data message to be sent in a POST request.
-- --url
-- The URL to fetch (mostly used when specifying URLs in a config file).
-- -i, --include
-- Defines whether the HTTP response headers are included in the output.
-- -v, --verbose
-- Enables the verbose operating mode.
-- -L, --location
-- Enables resending the request in case the requested page has moved to a different location.

---@param args string[]
function curl_parser.parse_arguments(args)
local iter = vim.iter(args)
---@type rest.Request
local req = {
-- TODO: add this to rest.Request type
meta = {
redirect = false,
},
url = "",
method = "GET",
headers = {},
cookies = {},
handlers = {},
}
local function any(value, list)
return vim.list_contains(list, value)
end
while true do
local arg = iter:next()
if not arg then
break
end
if any(arg, { "-X", "--request" }) then
req.method = iter:next()
elseif any(arg, { "-H", "--header" }) then
local pair = iter:next()
local key, value = pair:match("(%S+):%s*(.*)")
if not key then
logger.error("can't parse header:" .. pair)
else
key = key:lower()
req.headers[key] = req.headers[key] or {}
if value then
table.insert(req.headers[key], value)
end
end
-- TODO: handle more arguments
-- elseif any(arg, { "-u", "--user" }) then
-- elseif arg == "--basic" then
-- elseif arg == "--digest" then
elseif any(arg, { "-d", "--data", "--data-ascii", "--data-raw" }) then
-- handle external body with `@` syntax
local body = iter:next()
if arg ~= "--data-raw" and body:sub(1, 1) == "@" then
req.body = {
__TYPE = "external",
data = {
name = "",
path = body:sub(2),
},
}
else
req.body = {
__TYPE = "raw",
data = body
}
end
-- elseif arg == "--data-binary" then
-- elseif any(arg, { "-F", "--form" }) then
elseif arg == "--url" then
req.url = iter:next()
elseif any(arg, { "-L", "--location" }) then
req.meta.redirect = true
elseif arg:match("^-%a+$") then
local flags_iter = vim.gsplit(arg:sub(2), "")
for flag in flags_iter do
if flag == "L" then
req.meta.redirect = true
end
end
elseif req.url == "" and not vim.startswith(arg, "-") then
req.url = arg
else
logger.warn("unknown argument: " .. arg)
end
end
return req
end

return curl_parser
44 changes: 18 additions & 26 deletions lua/rest-nvim/parser/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@ local NAMED_REQUEST_QUERY = vim.treesitter.query.parse(
]]
)

---@param node TSNode
---@param field string
---@param source Source
---@return string|nil
local function get_node_field_text(node, field, source)
local n = node:field(field)[1]
return n and vim.treesitter.get_node_text(n, source) or nil
end

---@param src string
---@param context rest.Context
---@return string
Expand All @@ -63,8 +54,8 @@ local function parse_headers(req_node, source, context)
end)
local header_nodes = req_node:field("header")
for _, node in ipairs(header_nodes) do
local key = assert(get_node_field_text(node, "name", source))
local value = get_node_field_text(node, "value", source)
local key = assert(utils.ts_field_text(node, "name", source))
local value = utils.ts_field_text(node, "value", source)
key = expand_variables(key, context):lower()
if value then
value = expand_variables(value, context)
Expand Down Expand Up @@ -106,6 +97,7 @@ local function parse_urlencoded_form(str)
logger.error(("Error while parsing query '%s' from urlencoded form '%s'"):format(query_pairs, str))
return nil
end
-- TODO: encode value here
return vim.trim(key) .. "=" .. vim.trim(value)
end)
:join("&")
Expand All @@ -122,7 +114,7 @@ function parser.parse_body(content_type, body_node, source, context)
---@cast body rest.Request.Body
if node_type == "external_body" then
body.__TYPE = "external"
local path = assert(get_node_field_text(body_node, "path", source))
local path = assert(utils.ts_field_text(body_node, "path", source))
if type(source) ~= "number" then
logger.error("can't parse external body on non-existing http file")
return
Expand All @@ -133,7 +125,7 @@ function parser.parse_body(content_type, body_node, source, context)
basepath = basepath:gsub("^" .. vim.pesc(vim.uv.cwd() .. "/"), "")
path = vim.fs.normalize(vim.fs.joinpath(basepath, path))
body.data = {
name = get_node_field_text(body_node, "name", source),
name = utils.ts_field_text(body_node, "name", source),
path = path,
}
elseif node_type == "json_body" or content_type == "application/json" then
Expand Down Expand Up @@ -217,7 +209,7 @@ end
---@param source Source
---@return TSNode[]
function parser.get_all_request_nodes(source)
local _, tree = utils.ts_parse_source(source)
local _, tree = utils.ts_parse_source(source, "http")
local result = {}
for node, _ in tree:root():iter_children() do
if node:type() == "section" and #node:field("request") > 0 then
Expand All @@ -230,7 +222,7 @@ end
---@return TSNode?
function parser.get_request_node_by_name(name)
local source = 0
local _, tree = utils.ts_parse_source(source)
local _, tree = utils.ts_parse_source(source, "http")
local query = NAMED_REQUEST_QUERY
for id, node, _metadata, _match in query:iter_captures(tree:root(), source) do
local capture_name = query.captures[id]
Expand All @@ -248,8 +240,8 @@ end
---@param ctx rest.Context
function parser.parse_variable_declaration(vd_node, source, ctx)
vim.validate({ node = utils.ts_node_spec(vd_node, "variable_declaration") })
local name = assert(get_node_field_text(vd_node, "name", source))
local value = vim.trim(assert(get_node_field_text(vd_node, "value", source)))
local name = assert(utils.ts_field_text(vd_node, "name", source))
local value = vim.trim(assert(utils.ts_field_text(vd_node, "value", source)))
value = expand_variables(value, ctx)
ctx:set_global(name, value)
end
Expand All @@ -261,8 +253,8 @@ end
local function parse_script(node, source)
local lang = "javascript"
local prev_node = utils.ts_upper_node(node)
if prev_node and prev_node:type() == "comment" and get_node_field_text(prev_node, "name", source) == "lang" then
local value = get_node_field_text(prev_node, "value", source)
if prev_node and prev_node:type() == "comment" and utils.ts_field_text(prev_node, "name", source) == "lang" then
local value = utils.ts_field_text(prev_node, "value", source)
if value then
lang = value
end
Expand Down Expand Up @@ -304,7 +296,7 @@ end
---@param source Source
---@return string[]
function parser.get_request_names(source)
local _, tree = utils.ts_parse_source(source)
local _, tree = utils.ts_parse_source(source, "http")
local query = NAMED_REQUEST_QUERY
local result = {}
for id, node, _metadata, _match in query:iter_captures(tree:root(), source) do
Expand Down Expand Up @@ -365,7 +357,7 @@ function parser.parse(node, source, ctx)
local start_row = node:range()
parser.eval_context(source, ctx, start_row)
end
local method = get_node_field_text(req_node, "method", source)
local method = utils.ts_field_text(req_node, "method", source)
if not method then
logger.info("no method provided, falling back to 'GET'")
method = "GET"
Expand All @@ -379,7 +371,7 @@ function parser.parse(node, source, ctx)
for child, _ in node:iter_children() do
local child_type = child:type()
if child_type == "request" then
url = expand_variables(assert(get_node_field_text(req_node, "url", source)), ctx)
url = expand_variables(assert(utils.ts_field_text(req_node, "url", source)), ctx)
url = url:gsub("\n%s+", "")
elseif child_type == "pre_request_script" then
parser.parse_pre_request_script(child, source, ctx)
Expand All @@ -390,9 +382,9 @@ function parser.parse(node, source, ctx)
table.insert(handlers, handler)
end
elseif child_type == "request_separator" then
name = get_node_field_text(child, "value", source)
elseif child_type == "comment" and get_node_field_text(child, "name", source) == "name" then
name = get_node_field_text(child, "value", source) or name
name = utils.ts_field_text(child, "value", source)
elseif child_type == "comment" and utils.ts_field_text(child, "name", source) == "name" then
name = utils.ts_field_text(child, "value", source) or name
elseif child_type == "variable_declaration" then
parser.parse_variable_declaration(child, source, ctx)
end
Expand Down Expand Up @@ -455,7 +447,7 @@ function parser.parse(node, source, ctx)
name = name,
method = method,
url = url,
http_version = get_node_field_text(req_node, "version", source),
http_version = utils.ts_field_text(req_node, "version", source),
headers = headers,
cookies = {},
body = body,
Expand Down
31 changes: 20 additions & 11 deletions lua/rest-nvim/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
---@brief ]]

local logger = require("rest-nvim.logger")
-- local config = require("rest-nvim.config")

local utils = {}

Expand Down Expand Up @@ -86,11 +85,11 @@ end
function utils.parse_http_time(time_str)
local pattern = "(%a+), (%d+)[%s-](%a+)[%s-](%d+) (%d+):(%d+):(%d+) GMT"
local _, day, month_name, year, hour, min, sec = time_str:match(pattern)
-- stylua: ignore
local months = {
Jan = 1, Feb = 2, Mar = 3, Apr = 4, May = 5, Jun = 6,
Jul = 7, Aug = 8, Sep = 9, Oct = 10, Nov = 11, Dec = 12,
}
-- stylua: ignore
local months = {
Jan = 1, Feb = 2, Mar = 3, Apr = 4, May = 5, Jun = 6,
Jul = 7, Aug = 8, Sep = 9, Oct = 10, Nov = 11, Dec = 12,
}
local time_table = {
year = tonumber(year),
month = months[month_name],
Expand Down Expand Up @@ -186,20 +185,21 @@ function utils.ts_highlight_node(bufnr, node, ns, timeout)
end

---@param source string|integer
---@param lang string
---@return vim.treesitter.LanguageTree
function utils.ts_get_parser(source)
function utils.ts_get_parser(source, lang)
if type(source) == "string" then
return vim.treesitter.get_string_parser(source, "http")
return vim.treesitter.get_string_parser(source, lang)
else
return vim.treesitter.get_parser(source, "http")
return vim.treesitter.get_parser(source, lang)
end
end

---@param source string|integer
---@return vim.treesitter.LanguageTree
---@return TSTree
function utils.ts_parse_source(source)
local ts_parser = utils.ts_get_parser(source)
function utils.ts_parse_source(source, lang)
local ts_parser = utils.ts_get_parser(source, lang)
return ts_parser, assert(ts_parser:parse(false)[1])
end

Expand Down Expand Up @@ -238,6 +238,15 @@ function utils.ts_upper_node(node)
return min_node
end

---@param node TSNode
---@param field string
---@param source Source
---@return string|nil
function utils.ts_field_text(node, field, source)
local n = node:field(field)[1]
return n and vim.treesitter.get_node_text(n, source) or nil
end

---@param node TSNode
---@param expected_type string
---@return table
Expand Down
4 changes: 2 additions & 2 deletions spec/examples/examples_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end
describe("multi-line-url", function()
it("line breaks should be ignored", function()
local source = open("spec/examples/multi_line_url.http")
local _, tree = utils.ts_parse_source(source)
local _, tree = utils.ts_parse_source(source, "http")
local req_node = assert(tree:root():child(0))
local req = parser.parse(req_node, source)
assert.not_nil(req)
Expand Down Expand Up @@ -108,7 +108,7 @@ describe("builtin request hooks", function()
describe("set_content_type", function()
it("with external body", function()
local source = open("spec/examples/post_with_external_body.http")
local _, tree = utils.ts_parse_source(source)
local _, tree = utils.ts_parse_source(source, "http")
local req_node = assert(tree:root():child(0))
local req = assert(parser.parse(req_node, source))
_G.rest_request = req
Expand Down
Loading