Skip to content

Commit

Permalink
L009: Handle adding newline after {% endif %} at end of file (#2862)
Browse files Browse the repository at this point in the history
* L009: Handle adding newline after {% endif %} at end of file

* Lexer should add trailing placeholders if needed

* Fix L009 to run on the actual final segment, *even if meta*

* Update iter_patches() to return EnrichedFixPatch for better fix placement

* Fix broken test

* Comments, type hints

* First draft of refactor how we handle FixPatch vs EnrichedFixPatch

* Second phase of refactoring FixPatch

Co-authored-by: Barry Hart <barry.hart@mailchimp.com>
  • Loading branch information
barrywhart and Barry Hart authored Mar 15, 2022
1 parent 447ecf8 commit f6f6c2f
Show file tree
Hide file tree
Showing 10 changed files with 150 additions and 69 deletions.
18 changes: 0 additions & 18 deletions src/sqlfluff/core/linter/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,3 @@ class ParsedString(NamedTuple):
config: FluffConfig
fname: str
source_str: str


class EnrichedFixPatch(NamedTuple):
"""An edit patch for a source file."""

source_slice: slice
templated_slice: slice
fixed_raw: str
# The patch category, functions mostly for debugging and explanation
# than for function. It allows traceability of *why* this patch was
# generated.
patch_category: str
templated_str: str
source_str: str

def dedupe_tuple(self):
"""Generate a tuple of this fix for deduping."""
return (self.source_slice, self.fixed_raw)
35 changes: 11 additions & 24 deletions src/sqlfluff/core/linter/linted_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
from sqlfluff.core.templaters import TemplatedFile

# Classes needed only for type checking
from sqlfluff.core.parser.segments.base import BaseSegment, FixPatch
from sqlfluff.core.parser.segments.base import BaseSegment, FixPatch, EnrichedFixPatch

from sqlfluff.core.linter.common import NoQaDirective, EnrichedFixPatch
from sqlfluff.core.linter.common import NoQaDirective

# Instantiate the linter logger
linter_logger: logging.Logger = logging.getLogger("sqlfluff.linter")
Expand Down Expand Up @@ -203,9 +203,7 @@ def is_clean(self) -> bool:
return not any(self.get_violations(filter_ignore=True))

@staticmethod
def _log_hints(
patch: Union[EnrichedFixPatch, FixPatch], templated_file: TemplatedFile
):
def _log_hints(patch: FixPatch, templated_file: TemplatedFile):
"""Log hints for debugging during patch generation."""
# This next bit is ALL FOR LOGGING AND DEBUGGING
max_log_length = 10
Expand Down Expand Up @@ -279,18 +277,16 @@ def fix_string(self) -> Tuple[Any, bool]:
dedupe_buffer = []
# We use enumerate so that we get an index for each patch. This is entirely
# so when debugging logs we can find a given patch again!
patch: Union[EnrichedFixPatch, FixPatch]
patch: FixPatch # Could be FixPatch or its subclass, EnrichedFixPatch
for idx, patch in enumerate(
self.tree.iter_patches(templated_str=self.templated_file.templated_str)
self.tree.iter_patches(templated_file=self.templated_file)
):
linter_logger.debug(" %s Yielded patch: %s", idx, patch)
self._log_hints(patch, self.templated_file)

# Attempt to convert to source space.
# Get source_slice.
try:
source_slice = self.templated_file.templated_slice_to_source_slice(
patch.templated_slice,
)
enriched_patch = patch.enrich(self.templated_file)
except ValueError: # pragma: no cover
linter_logger.info(
" - Skipping. Source space Value Error. i.e. attempted "
Expand All @@ -301,10 +297,10 @@ def fix_string(self) -> Tuple[Any, bool]:
continue

# Check for duplicates
dedupe_tuple = (source_slice, patch.fixed_raw)
if dedupe_tuple in dedupe_buffer:
if enriched_patch.dedupe_tuple() in dedupe_buffer:
linter_logger.info(
" - Skipping. Source space Duplicate: %s", dedupe_tuple
" - Skipping. Source space Duplicate: %s",
enriched_patch.dedupe_tuple(),
)
continue

Expand All @@ -318,19 +314,10 @@ def fix_string(self) -> Tuple[Any, bool]:

# Get the affected raw slices.
local_raw_slices = self.templated_file.raw_slices_spanning_source_slice(
source_slice
enriched_patch.source_slice
)
local_type_list = [slc.slice_type for slc in local_raw_slices]

enriched_patch = EnrichedFixPatch(
source_slice=source_slice,
templated_slice=patch.templated_slice,
patch_category=patch.patch_category,
fixed_raw=patch.fixed_raw,
templated_str=self.templated_file.templated_str[patch.templated_slice],
source_str=self.templated_file.source_str[source_slice],
)

# Deal with the easy cases of 1) New code at end 2) only literals
if not local_type_list or set(local_type_list) == {"literal"}:
linter_logger.info(
Expand Down
25 changes: 25 additions & 0 deletions src/sqlfluff/core/parser/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,31 @@ def elements_to_segments(
)
)

# Generate placeholders for any source-only slices that *follow*
# the last element. This happens, for example, if a Jinja templated
# file ends with "{% endif %}", and there's no trailing newline.
if idx == len(elements) - 1:
so_slices = [
so
for so in source_only_slices
if so.source_idx >= source_slice.stop
]
for so_slice in so_slices:
segment_buffer.append(
TemplateSegment(
pos_marker=PositionMarker(
slice(so_slice.source_idx, so_slice.end_source_idx()),
slice(
element.template_slice.stop,
element.template_slice.stop,
),
templated_file,
),
source_str=so_slice.raw,
block_type=so_slice.slice_type,
)
)

# Convert to tuple before return
return tuple(segment_buffer)

Expand Down
76 changes: 65 additions & 11 deletions src/sqlfluff/core/parser/segments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@
from copy import deepcopy
from dataclasses import dataclass, field, replace
from io import StringIO
from typing import Any, Callable, Dict, Optional, List, Tuple, NamedTuple, Iterator
from typing import (
Any,
Callable,
Dict,
Optional,
List,
Tuple,
Iterator,
Union,
)
import logging

from tqdm import tqdm
Expand All @@ -36,21 +45,54 @@
from sqlfluff.core.parser.matchable import Matchable
from sqlfluff.core.parser.markers import PositionMarker
from sqlfluff.core.parser.context import ParseContext
from sqlfluff.core.templaters.base import TemplatedFile

# Instantiate the linter logger (only for use in methods involved with fixing.)
linter_logger = logging.getLogger("sqlfluff.linter")


class FixPatch(NamedTuple):
@dataclass
class FixPatch:
"""An edit patch for a templated file."""

templated_slice: slice
fixed_raw: str
# The patch category, functions mostly for debugging and explanation
# than for function. It allows traceability of *why* this patch was
# generated. It has no siginificance for processing.
# generated. It has no significance for processing.
patch_category: str

def enrich(self, templated_file: TemplatedFile) -> "EnrichedFixPatch":
"""Convert patch to source space."""
source_slice = templated_file.templated_slice_to_source_slice(
self.templated_slice,
)
return EnrichedFixPatch(
source_slice=source_slice,
templated_slice=self.templated_slice,
patch_category=self.patch_category,
fixed_raw=self.fixed_raw,
templated_str=templated_file.templated_str[self.templated_slice],
source_str=templated_file.source_str[source_slice],
)


@dataclass
class EnrichedFixPatch(FixPatch):
"""An edit patch for a source file."""

source_slice: slice
templated_str: str
source_str: str

def enrich(self, templated_file: TemplatedFile) -> "EnrichedFixPatch":
"""No-op override of base class function."""
return self

def dedupe_tuple(self):
"""Generate a tuple of this fix for deduping."""
return (self.source_slice, self.fixed_raw)


@dataclass
class AnchorEditInfo:
Expand Down Expand Up @@ -1176,7 +1218,9 @@ def _validate_segment_after_fixes(self, rule_code, dialect, fixes_applied, segme
def _log_apply_fixes_check_issue(message, *args): # pragma: no cover
linter_logger.critical(message, *args)

def iter_patches(self, templated_str: str) -> Iterator[FixPatch]:
def iter_patches(
self, templated_file: TemplatedFile
) -> Iterator[Union[EnrichedFixPatch, FixPatch]]:
"""Iterate through the segments generating fix patches.
The patches are generated in TEMPLATED space. This is important
Expand All @@ -1188,6 +1232,7 @@ def iter_patches(self, templated_str: str) -> Iterator[FixPatch]:
"""
# Does it match? If so we can ignore it.
assert self.pos_marker
templated_str = templated_file.templated_str
matches = self.raw == templated_str[self.pos_marker.templated_slice]
if matches:
return
Expand Down Expand Up @@ -1256,7 +1301,7 @@ def iter_patches(self, templated_str: str) -> Iterator[FixPatch]:
insert_buff = ""

# Now we deal with any changes *within* the segment itself.
yield from segment.iter_patches(templated_str=templated_str)
yield from segment.iter_patches(templated_file=templated_file)

# Once we've dealt with any patches from the segment, update
# our position markers.
Expand All @@ -1266,13 +1311,22 @@ def iter_patches(self, templated_str: str) -> Iterator[FixPatch]:
# or insert. Also valid if we still have an insertion buffer here.
end_diff = self.pos_marker.templated_slice.stop - templated_idx
if end_diff or insert_buff:
yield FixPatch(
slice(
self.pos_marker.templated_slice.stop - end_diff,
self.pos_marker.templated_slice.stop,
),
insert_buff,
source_slice = segment.pos_marker.source_slice
templated_slice = slice(
self.pos_marker.templated_slice.stop - end_diff,
self.pos_marker.templated_slice.stop,
)
# By returning an EnrichedFixPatch (rather than FixPatch), which
# includes a source_slice field, we ensure that fixes adjacent
# to source-only slices (e.g. {% endif %}) are placed
# appropriately relative to source-only slices.
yield EnrichedFixPatch(
source_slice=source_slice,
templated_slice=templated_slice,
patch_category="end_point",
fixed_raw=insert_buff,
templated_str=templated_file.templated_str[templated_slice],
source_str=templated_file.source_str[source_slice],
)

def edit(self, raw):
Expand Down
16 changes: 9 additions & 7 deletions src/sqlfluff/core/rules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,16 +656,18 @@ def indent(self) -> str:
space = " "
return space * self.tab_space_size if self.indent_unit == "space" else tab

def is_final_segment(self, context: RuleContext) -> bool:
def is_final_segment(self, context: RuleContext, filter_meta: bool = True) -> bool:
"""Is the current segment the final segment in the parse tree."""
if len(self.filter_meta(context.siblings_post)) > 0:
siblings_post = context.siblings_post
if filter_meta:
siblings_post = self.filter_meta(siblings_post)
if len(siblings_post) > 0:
# This can only fail on the last segment
return False
elif len(context.segment.segments) > 0:
# This can only fail on the last base segment
return False
elif context.segment.is_meta:
# We can't fail on a meta segment
elif filter_meta and context.segment.is_meta:
return False
else:
# We know we are at a leaf of the tree but not necessarily at the end of the
Expand All @@ -674,9 +676,9 @@ def is_final_segment(self, context: RuleContext) -> bool:
# one.
child_segment = context.segment
for parent_segment in context.parent_stack[::-1]:
possible_children = [
s for s in parent_segment.segments if not s.is_meta
]
possible_children = parent_segment.segments
if filter_meta:
possible_children = [s for s in possible_children if not s.is_meta]
if len(possible_children) > possible_children.index(child_segment) + 1:
return False
child_segment = parent_segment
Expand Down
11 changes: 10 additions & 1 deletion src/sqlfluff/core/templaters/slicers/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,6 @@ def _slice_template(self) -> List[RawFileSlice]:
# parts of the tag at a time.
unique_alternate_id = None
alternate_code = None
trimmed_content = ""
if elem_type.endswith("_end") or elem_type == "raw_begin":
block_type = block_types[elem_type]
block_subtype = None
Expand Down Expand Up @@ -436,6 +435,16 @@ def _slice_template(self) -> List[RawFileSlice]:
"endfor",
"endif",
):
# Replace RawSliceInfo for this slice with one that has
# alternate ID and code for tracking. This ensures, for
# instance, that if a file ends with "{% endif %} (with
# no newline following), that we still generate a
# TemplateSliceInfo for it.
unique_alternate_id = self.next_slice_id()
alternate_code = f"{result[-1].raw}\0{unique_alternate_id}_0"
self.raw_slice_info[result[-1]] = RawSliceInfo(
unique_alternate_id, alternate_code, []
)
# Record potential forward jump over this block.
self.raw_slice_info[result[stack[-1]]].next_slice_indices.append(
block_idx
Expand Down
2 changes: 1 addition & 1 deletion src/sqlfluff/rules/L009.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
"""
# We only care about the final segment of the parse tree.
if not self.is_final_segment(context):
if not self.is_final_segment(context, filter_meta=False):
return None

# Include current segment for complete stack and reverse.
Expand Down
10 changes: 5 additions & 5 deletions test/api/simple_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,16 @@
"description": "Keywords must be consistently upper case.",
},
{
"code": "L009",
"code": "L014",
"line_no": 1,
"line_pos": 34,
"description": "Files must end with a single trailing newline.",
"description": "Unquoted identifiers must be consistently lower case.",
},
{
"code": "L014",
"code": "L009",
"line_no": 1,
"line_pos": 34,
"description": "Unquoted identifiers must be consistently lower case.",
"line_pos": 41,
"description": "Files must end with a single trailing newline.",
},
]

Expand Down
20 changes: 18 additions & 2 deletions test/core/templaters/jinja_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,10 @@ def test__templater_jinja_slice_template(test, result):
("block_end", slice(113, 127, None), slice(11, 11, None)),
("block_start", slice(27, 46, None), slice(11, 11, None)),
("literal", slice(46, 57, None), slice(11, 22, None)),
("block_end", slice(57, 70, None), slice(22, 22, None)),
("block_start", slice(70, 89, None), slice(22, 22, None)),
("block_end", slice(100, 113, None), slice(22, 22, None)),
("block_end", slice(113, 127, None), slice(22, 22, None)),
],
),
(
Expand Down Expand Up @@ -910,8 +914,20 @@ def test__templater_jinja_slice_template(test, result):
("literal", slice(91, 92, None), slice(0, 0, None)),
("block_end", slice(92, 104, None), slice(0, 0, None)),
("literal", slice(104, 113, None), slice(0, 9, None)),
("templated", slice(113, 139, None), slice(9, 29, None)),
("literal", slice(139, 156, None), slice(29, 46, None)),
("templated", slice(113, 139, None), slice(9, 28, None)),
("literal", slice(139, 156, None), slice(28, 28, None)),
],
),
(
# Test for issue 2822: Handle slicing when there's no newline after
# the Jinja block end.
"{% if true %}\nSELECT 1 + 1\n{%- endif %}",
None,
[
("block_start", slice(0, 13, None), slice(0, 0, None)),
("literal", slice(13, 26, None), slice(0, 13, None)),
("literal", slice(26, 27, None), slice(13, 13, None)),
("block_end", slice(27, 39, None), slice(13, 13, None)),
],
),
],
Expand Down
Loading

0 comments on commit f6f6c2f

Please sign in to comment.