Skip to content

Commit

Permalink
fix nonreentrant analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Apr 25, 2024
1 parent 1592ea1 commit 3762d41
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 20 deletions.
8 changes: 1 addition & 7 deletions vyper/semantics/analysis/data_positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Generic, TypeVar

from vyper import ast as vy_ast
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, StorageLayoutException
from vyper.semantics.analysis.base import VarOffset
from vyper.semantics.analysis.utils import get_reentrancy_key_location
from vyper.semantics.data_locations import DataLocation
from vyper.typing import StorageLayout

Expand Down Expand Up @@ -216,12 +216,6 @@ def _get_allocatable(vyper_module: vy_ast.Module) -> list[vy_ast.VyperNode]:
return [node for node in vyper_module.body if isinstance(node, allocable)]


def get_reentrancy_key_location() -> DataLocation:
if version_check(begin="cancun"):
return DataLocation.TRANSIENT
return DataLocation.STORAGE


_LAYOUT_KEYS = {
DataLocation.CODE: "code_layout",
DataLocation.TRANSIENT: "transient_storage_layout",
Expand Down
11 changes: 2 additions & 9 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,19 +782,12 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:
raise CallViolation(msg, node.parent, hint=hint)

if not func_type.from_interface:
for s in func_type.get_variable_writes():
if s.variable.is_state_variable():
func_info._writes.add(s)
for s in func_type.get_variable_reads():
if s.variable.is_state_variable():
func_info._reads.add(s)
func_info._writes.update(func_type.get_variable_writes())
func_info._reads.update(func_type.get_variable_reads())

if self.function_analyzer:
self._check_call_mutability(func_type.mutability)

if func_type.uses_state():
self.function_analyzer._handle_module_access(node.func)

if func_type.is_deploy and not self.func.is_deploy:
raise CallViolation(
f"Cannot call an @{func_type.visibility} function from "
Expand Down
5 changes: 3 additions & 2 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,12 @@ def validate_used_modules(self):
all_used_modules = OrderedSet()

for f in module_t.functions.values():
for u in f.get_used_modules():
all_used_modules.add(u.module_t)
all_used_modules.update([u.module_t for u in f.get_used_modules()])

for decl in module_t.exports_decls:
info = decl._metadata["exports_info"]
for f in info.functions:
all_used_modules.update([u.module_t for u in f.get_used_modules()])
all_used_modules.update([u.module_t for u in info.used_modules])

for used_module in all_used_modules:
Expand Down
16 changes: 15 additions & 1 deletion vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, Iterable, List

from vyper import ast as vy_ast
from vyper.evm.opcodes import version_check
from vyper.exceptions import (
CompilerPanic,
InvalidLiteral,
Expand All @@ -17,7 +18,14 @@
ZeroDivisionException,
)
from vyper.semantics import types
from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarAccess, VarInfo
from vyper.semantics.analysis.base import (
DataLocation,
ExprInfo,
Modifiability,
ModuleInfo,
VarAccess,
VarInfo,
)

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.base
begins an import cycle.
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
Expand Down Expand Up @@ -52,6 +60,12 @@ def uses_state(var_accesses: Iterable[VarAccess]) -> bool:
return any(s.variable.is_state_variable() for s in var_accesses)


def get_reentrancy_key_location() -> DataLocation:
if version_check(begin="cancun"):
return DataLocation.TRANSIENT
return DataLocation.STORAGE


class _ExprAnalyser:
"""
Node type-checker class.
Expand Down
14 changes: 13 additions & 1 deletion vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
ModuleInfo,
StateMutability,
VarAccess,
VarInfo,
VarOffset,
)
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
get_reentrancy_key_location,
uses_state,
validate_expected_type,
)
Expand Down Expand Up @@ -131,6 +133,16 @@ def __init__(
# reads of variables from this function
self._variable_reads: OrderedSet[VarAccess] = OrderedSet()

if nonreentrant:
location = get_reentrancy_key_location()
# dummy varinfo object. it doesn't matter where location is,
# so long as it registers as a state variable
dummy_varinfo = VarInfo(typ=self, location=location, decl_node=ast_def) # type: ignore
nonreentrant_access = VarAccess(dummy_varinfo, path=())
self._variable_reads.add(nonreentrant_access)
if self.is_mutable:
self._variable_writes.add(nonreentrant_access)

# list of modules used (accessed state) by this function
self._used_modules: OrderedSet[ModuleInfo] = OrderedSet()

Expand Down Expand Up @@ -165,7 +177,7 @@ def get_variable_accesses(self):
return self._variable_reads | self._variable_writes

def uses_state(self):
return self.nonreentrant or uses_state(self.get_variable_accesses())
return uses_state(self.get_variable_accesses())

def get_used_modules(self):
# _used_modules is populated during analysis
Expand Down

0 comments on commit 3762d41

Please sign in to comment.