From f6e7e93e00dd43b809e5f60f082693b734e37228 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Tue, 5 Jul 2022 10:51:24 -0700 Subject: [PATCH] chore(rls): Remove passing global username (#20344) * chore(rls): Remove passing global username * Update manager.py * Update manager.py * Update manager.py * Update manager.py Co-authored-by: John Bodley (cherry picked from commit ad308fbde251d0ed262a90b8d818c977dfe73d0e) --- superset/connectors/sqla/models.py | 5 +---- superset/security/manager.py | 19 +++++-------------- superset/sql_lab.py | 1 - superset/sql_parse.py | 6 ++---- tests/unit_tests/sql_parse_tests.py | 1 - 5 files changed, 8 insertions(+), 24 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 61d708f021f3e..67a2f97c841d0 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1169,7 +1169,6 @@ def is_alias_used_in_orderby(col: ColumnElement) -> bool: def get_sqla_row_level_filters( self, template_processor: BaseTemplateProcessor, - username: Optional[str] = None, ) -> List[TextClause]: """ Return the appropriate row level security filters for this table and the @@ -1177,14 +1176,12 @@ def get_sqla_row_level_filters( Flask global namespace. :param template_processor: The template processor to apply to the filters. - :param username: Optional username if there's no user in the Flask global - namespace. :returns: A list of SQL clauses to be ANDed together. """ all_filters: List[TextClause] = [] filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list) try: - for filter_ in security_manager.get_rls_filters(self, username): + for filter_ in security_manager.get_rls_filters(self): clause = self.text( f"({template_processor.process_template(filter_.clause)})" ) diff --git a/superset/security/manager.py b/superset/security/manager.py index 890a09415ecb3..f3d717b7065e9 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1147,25 +1147,16 @@ def get_guest_rls_filters( ] return [] - def get_rls_filters( - self, - table: "BaseDatasource", - username: Optional[str] = None, - ) -> List[SqlaQuery]: + def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]: """ Retrieves the appropriate row level security filters for the current user and the passed table. - :param BaseDatasource table: The table to check against. - :param Optional[str] username: Optional username if there's no user in the Flask - global namespace. + :param table: The table to check against :returns: A list of filters """ - if hasattr(g, "user"): - user = g.user - elif username: - user = self.find_user(username=username) - else: + + if not (hasattr(g, "user") and g.user is not None): return [] # pylint: disable=import-outside-toplevel @@ -1175,7 +1166,7 @@ def get_rls_filters( RowLevelSecurityFilter, ) - user_roles = [role.id for role in self.get_user_roles(user)] + user_roles = [role.id for role in self.get_user_roles(g.user)] regular_filter_roles = ( self.get_session() .query(RLSFilterRoles.c.rls_filter_id) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 785d16327f7f2..571fd94219aa7 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -208,7 +208,6 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem parsed_query._parsed[0], # pylint: disable=protected-access database.id, query.schema, - username=get_username(), ) ) ) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index b585810f785a2..d377986f56573 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -553,7 +553,6 @@ def get_rls_for_table( candidate: Token, database_id: int, default_schema: Optional[str], - username: Optional[str] = None, ) -> Optional[TokenList]: """ Given a table name, return any associated RLS predicates. @@ -586,7 +585,7 @@ def get_rls_for_table( template_processor = dataset.get_template_processor() predicate = " AND ".join( str(filter_) - for filter_ in dataset.get_sqla_row_level_filters(template_processor, username) + for filter_ in dataset.get_sqla_row_level_filters(template_processor) ) if not predicate: return None @@ -601,7 +600,6 @@ def insert_rls( token_list: TokenList, database_id: int, default_schema: Optional[str], - username: Optional[str] = None, ) -> TokenList: """ Update a statement inplace applying any associated RLS predicates. @@ -623,7 +621,7 @@ def insert_rls( elif state == InsertRLSState.SEEN_SOURCE and ( isinstance(token, Identifier) or token.ttype == Keyword ): - rls = get_rls_for_table(token, database_id, default_schema, username) + rls = get_rls_for_table(token, database_id, default_schema) if rls: state = InsertRLSState.FOUND_TABLE diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 1d2c788496af0..98eceebd47136 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1409,7 +1409,6 @@ def get_rls_for_table( candidate: Token, database_id: int, default_schema: str, - username: Optional[str] = None, ) -> Optional[TokenList]: """ Return the RLS ``condition`` if ``candidate`` matches ``table``.