Skip to content

Commit 8f96025

Browse files
authored
Merge pull request #2859 from tomlau10/feat/param_infer_for_override
Infer function parameter types when overriding the same-named class function in an instance of that class
2 parents fbd5bb1 + ee5e872 commit 8f96025

File tree

5 files changed

+90
-15
lines changed

5 files changed

+90
-15
lines changed

Diff for: changelog.md

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Unreleased
44
<!-- Add all new changes here. They will be moved under a version at release -->
55
* `NEW` Added support for Japanese locale
6+
* `NEW` Infer function parameter types when overriding the same-named class function in an instance of that class [#2158](https://github.com/LuaLS/lua-language-server/issues/2158)
67
* `FIX` Eliminate floating point error in test benchmark output
78
* `FIX` Remove luamake install from make scripts
89

Diff for: script/core/diagnostics/duplicate-set-field.lua

+6
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ return function (uri, callback)
6868
if not defValue or defValue.type ~= 'function' then
6969
goto CONTINUE
7070
end
71+
if vm.getDefinedClass(guide.getUri(def), def.node)
72+
and not vm.getDefinedClass(guide.getUri(src), src.node)
73+
then
74+
-- allow type variable to override function defined in class variable
75+
goto CONTINUE
76+
end
7177
callback {
7278
start = src.start,
7379
finish = src.finish,

Diff for: script/vm/compiler.lua

+45-15
Original file line numberDiff line numberDiff line change
@@ -1117,26 +1117,56 @@ local function compileFunctionParam(func, source)
11171117
end
11181118
---@cast aindex integer
11191119

1120-
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
11211120
local funcNode = vm.compileNode(func)
1122-
local found = false
1123-
for n in funcNode:eachObject() do
1124-
if n.type == 'doc.type.function' and n.args[aindex] then
1125-
local argNode = vm.compileNode(n.args[aindex])
1126-
for an in argNode:eachObject() do
1127-
if an.type ~= 'doc.generic.name' then
1128-
vm.setNode(source, an)
1121+
if func.parent.type == 'callargs' then
1122+
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
1123+
for n in funcNode:eachObject() do
1124+
if n.type == 'doc.type.function' and n.args[aindex] then
1125+
local argNode = vm.compileNode(n.args[aindex])
1126+
for an in argNode:eachObject() do
1127+
if an.type ~= 'doc.generic.name' then
1128+
vm.setNode(source, an)
1129+
end
11291130
end
1130-
end
1131-
-- NOTE: keep existing behavior for local call which only set type based on the 1st match
1132-
if func.parent.type == 'callargs' then
1131+
-- NOTE: keep existing behavior for function as argument which only set type based on the 1st match
11331132
return true
11341133
end
1135-
found = true
11361134
end
1137-
end
1138-
if found then
1139-
return true
1135+
else
1136+
-- function declaration: use info from all `fun()`, also from the base function when overriding
1137+
--[[
1138+
---@type fun(x: string)|fun(x: number)
1139+
local function f1(x) end --> x -> string|number
1140+
1141+
---@overload fun(x: string)
1142+
---@overload fun(x: number)
1143+
local function f2(x) end --> x -> string|number
1144+
1145+
---@class A
1146+
local A = {}
1147+
---@param x number
1148+
function A:f(x) end --> x -> number
1149+
---@type A
1150+
local a = {}
1151+
function a:f(x) end --> x -> number
1152+
]]
1153+
local found = false
1154+
for n in funcNode:eachObject() do
1155+
if (n.type == 'doc.type.function' or n.type == 'function')
1156+
and n.args[aindex] and n.args[aindex] ~= source
1157+
then
1158+
local argNode = vm.compileNode(n.args[aindex])
1159+
for an in argNode:eachObject() do
1160+
if an.type ~= 'doc.generic.name' then
1161+
vm.setNode(source, an)
1162+
end
1163+
end
1164+
found = true
1165+
end
1166+
end
1167+
if found then
1168+
return true
1169+
end
11401170
end
11411171

11421172
local derviationParam = config.get(guide.getUri(func), 'Lua.type.inferParamType')

Diff for: test/diagnostics/duplicate-set-field.lua

+26
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,29 @@ else
7272
function X.f() end
7373
end
7474
]]
75+
76+
TEST [[
77+
---@class A
78+
X = {}
79+
80+
function X:f() end
81+
82+
---@type x
83+
local x
84+
85+
function x:f() end
86+
]]
87+
88+
TEST [[
89+
---@class A
90+
X = {}
91+
92+
function X:f() end
93+
94+
---@type x
95+
local x
96+
97+
function <!x:f!>() end
98+
99+
function <!x:f!>() end
100+
]]

Diff for: test/type_inference/common.lua

+12
Original file line numberDiff line numberDiff line change
@@ -4441,3 +4441,15 @@ local B = {}
44414441
44424442
function B:func(<?x?>) end
44434443
]]
4444+
4445+
TEST 'number' [[
4446+
---@class A
4447+
local A = {}
4448+
4449+
---@param x number
4450+
function A:func(x) end
4451+
4452+
---@type A
4453+
local a = {}
4454+
function a:func(<?x?>) end
4455+
]]

0 commit comments

Comments
 (0)