Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: check records nominally when resolving metamethods for operators #710

Merged
merged 1 commit into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions spec/metamethods/add_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,35 @@ describe("binary metamethod __add", function()

print((10 + s).x)
]]))

it("preserves nominal type checking when resolving metamethods for operators", util.check_type_error([[
local type Temperature = record
n: number
metamethod __add: function(t1: Temperature, t2: Temperature): Temperature
end

local type Date = record
n: number
metamethod __add: function(t1: Date, t2: Date): Date
end

local temp2: Temperature = { n = 45 }
local birthday2 : Date = { n = 34 }

setmetatable(temp2, {
__add = function(t1: Temperature, t2: Temperature): Temperature
return { n = t1.n + t2.n }
end,
})

setmetatable(birthday2, {
__add = function(t1: Date, t2: Date): Date
return { n = t1.n + t2.n }
end,
})

print((temp2 + birthday2).n)
]], {
{ y = 26, msg = "Date is not a Temperature" },
}))
end)
14 changes: 7 additions & 7 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7621,7 +7621,7 @@ tl.type_check = function(ast, opts)
end
end

local function check_metamethod(node, op, a, b)
local function check_metamethod(node, op, a, b, orig_a, orig_b)
local method_name
local where_args
local args
Expand All @@ -7636,11 +7636,11 @@ tl.type_check = function(ast, opts)
if a and b then
method_name = binop_to_metamethod[op]
where_args = { node.e1, node.e2 }
args = { typename = "tuple", a, b }
args = { typename = "tuple", orig_a, orig_b }
else
method_name = unop_to_metamethod[op]
where_args = { node.e1 }
args = { typename = "tuple", a }
args = { typename = "tuple", orig_a }
end

local metamethod = a.meta_fields and a.meta_fields[method_name or ""]
Expand Down Expand Up @@ -7679,7 +7679,7 @@ tl.type_check = function(ast, opts)
return tbl.fields[key]
end

local meta_t = check_metamethod(rec, "@index", tbl, STRING)
local meta_t = check_metamethod(rec, "@index", tbl, STRING, tbl, STRING)
if meta_t then
return meta_t
end
Expand Down Expand Up @@ -8050,7 +8050,7 @@ tl.type_check = function(ast, opts)
errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b
end

local meta_t = check_metamethod(anode, "@index", a, orig_b)
local meta_t = check_metamethod(anode, "@index", a, orig_b, orig_a, orig_b)
if meta_t then
return meta_t
end
Expand Down Expand Up @@ -10004,7 +10004,7 @@ tl.type_check = function(ast, opts)
node.type = types_op[a.typename]
local meta_on_operator
if not node.type then
node.type, meta_on_operator = check_metamethod(node, node.op.op, a)
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, nil, orig_a, nil)
if not node.type then
node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a))
end
Expand Down Expand Up @@ -10050,7 +10050,7 @@ tl.type_check = function(ast, opts)
node.type = types_op[a.typename] and types_op[a.typename][b.typename]
local meta_on_operator
if not node.type then
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b)
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b, orig_a, orig_b)
if not node.type then
node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b))
if node.op.op == "or" and is_valid_union(unite({ orig_a, orig_b })) then
Expand Down
14 changes: 7 additions & 7 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -7621,7 +7621,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
end
end

local function check_metamethod(node: Node, op: string, a: Type, b: Type): Type, integer
local function check_metamethod(node: Node, op: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer
local method_name: string
local where_args: {Node}
local args: Type
Expand All @@ -7636,11 +7636,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
if a and b then
method_name = binop_to_metamethod[op]
where_args = { node.e1, node.e2 }
args = { typename = "tuple", a, b }
args = { typename = "tuple", orig_a, orig_b }
else
method_name = unop_to_metamethod[op]
where_args = { node.e1 }
args = { typename = "tuple", a }
args = { typename = "tuple", orig_a }
end

local metamethod = a.meta_fields and a.meta_fields[method_name or ""]
Expand Down Expand Up @@ -7679,7 +7679,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
return tbl.fields[key]
end

local meta_t = check_metamethod(rec, "@index", tbl, STRING)
local meta_t = check_metamethod(rec, "@index", tbl, STRING, tbl, STRING)
if meta_t then
return meta_t
end
Expand Down Expand Up @@ -8050,7 +8050,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b
end

local meta_t = check_metamethod(anode, "@index", a, orig_b)
local meta_t = check_metamethod(anode, "@index", a, orig_b, orig_a, orig_b)
if meta_t then
return meta_t
end
Expand Down Expand Up @@ -10004,7 +10004,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
node.type = types_op[a.typename]
local meta_on_operator: integer
if not node.type then
node.type, meta_on_operator = check_metamethod(node, node.op.op, a)
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, nil, orig_a, nil)
if not node.type then
node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a))
end
Expand Down Expand Up @@ -10050,7 +10050,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
node.type = types_op[a.typename] and types_op[a.typename][b.typename]
local meta_on_operator: integer
if not node.type then
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b)
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b, orig_a, orig_b)
if not node.type then
node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b))
if node.op.op == "or" and is_valid_union(unite({orig_a, orig_b})) then
Expand Down
Loading