Skip to content

Commit

Permalink
fix: check records nominally when resolving metamethods for operators (
Browse files Browse the repository at this point in the history
  • Loading branch information
hishamhm authored Oct 21, 2023
1 parent 821e161 commit a5dd26b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 14 deletions.
31 changes: 31 additions & 0 deletions spec/metamethods/add_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,35 @@ describe("binary metamethod __add", function()
print((10 + s).x)
]]))

it("preserves nominal type checking when resolving metamethods for operators", util.check_type_error([[
local type Temperature = record
n: number
metamethod __add: function(t1: Temperature, t2: Temperature): Temperature
end
local type Date = record
n: number
metamethod __add: function(t1: Date, t2: Date): Date
end
local temp2: Temperature = { n = 45 }
local birthday2 : Date = { n = 34 }
setmetatable(temp2, {
__add = function(t1: Temperature, t2: Temperature): Temperature
return { n = t1.n + t2.n }
end,
})
setmetatable(birthday2, {
__add = function(t1: Date, t2: Date): Date
return { n = t1.n + t2.n }
end,
})
print((temp2 + birthday2).n)
]], {
{ y = 26, msg = "Date is not a Temperature" },
}))
end)
14 changes: 7 additions & 7 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7621,7 +7621,7 @@ tl.type_check = function(ast, opts)
end
end

local function check_metamethod(node, op, a, b)
local function check_metamethod(node, op, a, b, orig_a, orig_b)
local method_name
local where_args
local args
Expand All @@ -7636,11 +7636,11 @@ tl.type_check = function(ast, opts)
if a and b then
method_name = binop_to_metamethod[op]
where_args = { node.e1, node.e2 }
args = { typename = "tuple", a, b }
args = { typename = "tuple", orig_a, orig_b }
else
method_name = unop_to_metamethod[op]
where_args = { node.e1 }
args = { typename = "tuple", a }
args = { typename = "tuple", orig_a }
end

local metamethod = a.meta_fields and a.meta_fields[method_name or ""]
Expand Down Expand Up @@ -7679,7 +7679,7 @@ tl.type_check = function(ast, opts)
return tbl.fields[key]
end

local meta_t = check_metamethod(rec, "@index", tbl, STRING)
local meta_t = check_metamethod(rec, "@index", tbl, STRING, tbl, STRING)
if meta_t then
return meta_t
end
Expand Down Expand Up @@ -8050,7 +8050,7 @@ tl.type_check = function(ast, opts)
errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b
end

local meta_t = check_metamethod(anode, "@index", a, orig_b)
local meta_t = check_metamethod(anode, "@index", a, orig_b, orig_a, orig_b)
if meta_t then
return meta_t
end
Expand Down Expand Up @@ -10004,7 +10004,7 @@ tl.type_check = function(ast, opts)
node.type = types_op[a.typename]
local meta_on_operator
if not node.type then
node.type, meta_on_operator = check_metamethod(node, node.op.op, a)
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, nil, orig_a, nil)
if not node.type then
node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a))
end
Expand Down Expand Up @@ -10050,7 +10050,7 @@ tl.type_check = function(ast, opts)
node.type = types_op[a.typename] and types_op[a.typename][b.typename]
local meta_on_operator
if not node.type then
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b)
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b, orig_a, orig_b)
if not node.type then
node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b))
if node.op.op == "or" and is_valid_union(unite({ orig_a, orig_b })) then
Expand Down
14 changes: 7 additions & 7 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -7621,7 +7621,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
end
end

local function check_metamethod(node: Node, op: string, a: Type, b: Type): Type, integer
local function check_metamethod(node: Node, op: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer
local method_name: string
local where_args: {Node}
local args: Type
Expand All @@ -7636,11 +7636,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
if a and b then
method_name = binop_to_metamethod[op]
where_args = { node.e1, node.e2 }
args = { typename = "tuple", a, b }
args = { typename = "tuple", orig_a, orig_b }
else
method_name = unop_to_metamethod[op]
where_args = { node.e1 }
args = { typename = "tuple", a }
args = { typename = "tuple", orig_a }
end

local metamethod = a.meta_fields and a.meta_fields[method_name or ""]
Expand Down Expand Up @@ -7679,7 +7679,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
return tbl.fields[key]
end

local meta_t = check_metamethod(rec, "@index", tbl, STRING)
local meta_t = check_metamethod(rec, "@index", tbl, STRING, tbl, STRING)
if meta_t then
return meta_t
end
Expand Down Expand Up @@ -8050,7 +8050,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b
end

local meta_t = check_metamethod(anode, "@index", a, orig_b)
local meta_t = check_metamethod(anode, "@index", a, orig_b, orig_a, orig_b)
if meta_t then
return meta_t
end
Expand Down Expand Up @@ -10004,7 +10004,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
node.type = types_op[a.typename]
local meta_on_operator: integer
if not node.type then
node.type, meta_on_operator = check_metamethod(node, node.op.op, a)
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, nil, orig_a, nil)
if not node.type then
node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a))
end
Expand Down Expand Up @@ -10050,7 +10050,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
node.type = types_op[a.typename] and types_op[a.typename][b.typename]
local meta_on_operator: integer
if not node.type then
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b)
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b, orig_a, orig_b)
if not node.type then
node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b))
if node.op.op == "or" and is_valid_union(unite({orig_a, orig_b})) then
Expand Down

0 comments on commit a5dd26b

Please sign in to comment.