Skip to content

Commit

Permalink
Introduce CompiledCGPatterns capturing the TS-Query and regex (placeh…
Browse files Browse the repository at this point in the history
…older)

ghstack-source-id: cf2baf162fcbfc57e097e1fbc6e90c5ae2869b74
Pull Request resolved: #527
  • Loading branch information
ketkarameya committed Jul 5, 2023
1 parent 3368185 commit 10a862b
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 69 deletions.
61 changes: 59 additions & 2 deletions src/models/capture_group_patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ Copyright (c) 2023 Uber Technologies, Inc.
use crate::{
models::Validator,
utilities::{
tree_sitter_utilities::{get_ts_query_parser, number_of_errors},
tree_sitter_utilities::{get_all_matches_for_query, get_ts_query_parser, number_of_errors},
Instantiate,
},
};
use pyo3::prelude::pyclass;
use regex::Regex;
use serde_derive::Deserialize;
use std::collections::HashMap;
use tree_sitter::{Node, Query};

use super::matches::Match;

#[pyclass]
#[derive(Deserialize, Debug, Clone, Default, PartialEq, Hash, Eq)]
Expand All @@ -38,12 +42,18 @@ impl CGPattern {

impl Validator for CGPattern {
fn validate(&self) -> Result<(), String> {
if self.pattern().starts_with("rgx ") {
panic!("Regex not supported")
}
let mut parser = get_ts_query_parser();
parser
.parse(self.pattern(), None)
.filter(|x| number_of_errors(&x.root_node()) == 0)
.map(|_| Ok(()))
.unwrap_or(Err(format!("Cannot parse - {}", self.pattern())))
.unwrap_or(Err(format!(
"Cannot parse the tree-sitter query - {}",
self.pattern()
)))
}
}

Expand All @@ -56,3 +66,50 @@ impl Instantiate for CGPattern {
CGPattern::new(self.pattern().instantiate(&substitutions))
}
}

#[derive(Debug)]
pub(crate) enum CompiledCGPattern {
Q(Query),
R(Regex), // Regex is not yet supported
}

impl CompiledCGPattern {
/// Applies the query upon the given node, and gets all the matches
/// # Arguments
/// * `node` - the root node to apply the query upon
/// * `source_code` - the corresponding source code string for the node.
/// * `query` - the query to be applied
/// * `recursive` - if `true` it matches the query to `self` and `self`'s sub-ASTs, else it matches the `query` only to `self`.
///
/// # Returns
/// A vector of `tuples` containing the range of the matches in the source code and the corresponding mapping for the tags (to code snippets).
/// By default it returns the range of the outermost node for each query match.
/// If `replace_node` is provided in the rule, it returns the range of the node corresponding to that tag.
pub(crate) fn get_match(&self, node: &Node, source_code: &str, recursive: bool) -> Option<Match> {
if let Some(m) = self
.get_matches(node, source_code.to_string(), recursive, None, None)
.first()
{
return Some(m.clone());
}
None
}

/// Applies the pattern upon the given `node`, and gets all the matches
pub(crate) fn get_matches(
&self, node: &Node, source_code: String, recursive: bool, replace_node: Option<String>,
replace_node_idx: Option<u8>,
) -> Vec<Match> {
match self {
CompiledCGPattern::Q(query) => get_all_matches_for_query(
node,
source_code,
query,
recursive,
replace_node,
replace_node_idx,
),
CompiledCGPattern::R(_) => panic!("Regex is not yet supported!!!"),
}
}
}
21 changes: 5 additions & 16 deletions src/models/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ use pyo3::prelude::{pyclass, pymethods};
use serde_derive::Deserialize;
use tree_sitter::Node;

use crate::utilities::{
gen_py_str_methods,
tree_sitter_utilities::{get_all_matches_for_query, get_match_for_query, get_node_for_range},
};
use crate::utilities::{gen_py_str_methods, tree_sitter_utilities::get_node_for_range};

use super::{
capture_group_patterns::CGPattern, default_configs::default_child_count,
Expand Down Expand Up @@ -415,9 +412,8 @@ impl SourceCodeUnit {
}

while let Some(parent) = current_node.parent() {
if let Some(p_match) =
get_match_for_query(&parent, self.code(), rule_store.query(ts_query), false)
{
let pattern = rule_store.query(ts_query);
if let Some(p_match) = pattern.get_match(&parent, self.code(), false) {
let matched_ancestor = get_node_for_range(
self.root_node(),
p_match.range().start_byte,
Expand All @@ -442,14 +438,7 @@ impl SourceCodeUnit {

// Retrieve all matches within the ancestor node
let contains_query = &rule_store.query(filter.contains());
let matches = get_all_matches_for_query(
ancestor,
self.code().to_string(),
contains_query,
true,
None,
None,
);
let matches = contains_query.get_matches(ancestor, self.code().to_string(), true, None, None);
let at_least = filter.at_least as usize;
let at_most = filter.at_most as usize;
// Validate if the count of matches falls within the expected range
Expand All @@ -464,7 +453,7 @@ impl SourceCodeUnit {
// Check if there's a match within the scope node
// If one of the filters is not satisfied, return false
let query = &rule_store.query(ts_query);
if get_match_for_query(ancestor, self.code(), query, true).is_some() {
if query.get_match(ancestor, self.code(), true).is_some() {
return false;
}
}
Expand Down
10 changes: 4 additions & 6 deletions src/models/matches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ use pyo3::prelude::{pyclass, pymethods};
use serde_derive::{Deserialize, Serialize};
use tree_sitter::Node;

use crate::utilities::{
gen_py_str_methods,
tree_sitter_utilities::{get_all_matches_for_query, get_node_for_range},
};
use crate::utilities::{gen_py_str_methods, tree_sitter_utilities::get_node_for_range};

use super::{
piranha_arguments::PiranhaArguments, rule::InstantiatedRule, rule_store::RuleStore,
Expand Down Expand Up @@ -291,10 +288,11 @@ impl SourceCodeUnit {
} else {
(rule.replace_node(), rule.replace_idx())
};
let mut all_query_matches = get_all_matches_for_query(

let pattern = rule_store.query(&rule.query());
let mut all_query_matches = pattern.get_matches(
&node,
self.code().to_string(),
rule_store.query(&rule.query()),
recursive,
replace_node_tag,
replace_node_idx,
Expand Down
20 changes: 13 additions & 7 deletions src/models/rule_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,22 @@ use itertools::Itertools;
use jwalk::WalkDir;
use log::{debug, trace};
use regex::Regex;
use tree_sitter::Query;

use crate::{
models::capture_group_patterns::CGPattern, models::piranha_arguments::PiranhaArguments,
models::scopes::ScopeQueryGenerator, utilities::read_file,
};

use super::{language::PiranhaLanguage, rule::InstantiatedRule};
use super::{
capture_group_patterns::CompiledCGPattern, language::PiranhaLanguage, rule::InstantiatedRule,
};
use glob::Pattern;

/// This maintains the state for Piranha.
#[derive(Debug, Getters, Default)]
pub(crate) struct RuleStore {
// Caches the compiled tree-sitter queries.
rule_query_cache: HashMap<String, Query>,
rule_query_cache: HashMap<String, CompiledCGPattern>,
// Current global rules to be applied.
#[get = "pub"]
global_rules: Vec<InstantiatedRule>,
Expand Down Expand Up @@ -75,11 +76,16 @@ impl RuleStore {

/// Get the compiled query for the `query_str` from the cache
/// else compile it, add it to the cache and return it.
pub(crate) fn query(&mut self, query_str: &CGPattern) -> &Query {
self
pub(crate) fn query(&mut self, cg_pattern: &CGPattern) -> &CompiledCGPattern {
let pattern = cg_pattern.pattern();
if pattern.starts_with("rgx ") {
panic!("Regex not supported.")
}

&*self
.rule_query_cache
.entry(query_str.pattern())
.or_insert_with(|| self.language.create_query(query_str.pattern()))
.entry(pattern.to_string())
.or_insert_with(|| CompiledCGPattern::Q(self.language.create_query(pattern)))
}

// For the given scope level, get the ScopeQueryGenerator from the `scope_config.toml` file
Expand Down
9 changes: 2 additions & 7 deletions src/models/scopes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ Copyright (c) 2023 Uber Technologies, Inc.

use super::capture_group_patterns::CGPattern;
use super::{rule_store::RuleStore, source_code_unit::SourceCodeUnit};
use crate::utilities::tree_sitter_utilities::get_match_for_query;
use crate::utilities::tree_sitter_utilities::get_node_for_range;
use crate::utilities::Instantiate;
use derive_builder::Builder;
Expand Down Expand Up @@ -65,12 +64,8 @@ impl SourceCodeUnit {
changed_node.kind()
);
for m in &scope_enclosing_nodes {
if let Some(p_match) = get_match_for_query(
&changed_node,
self.code(),
rules_store.query(m.enclosing_node()),
false,
) {
let pattern = rules_store.query(m.enclosing_node());
if let Some(p_match) = pattern.get_match(&changed_node, self.code(), false) {
// Generate the scope query for the specific context by substituting the
// the tags with code snippets appropriately in the `generator` query.
return m.scope().instantiate(p_match.matches());
Expand Down
12 changes: 3 additions & 9 deletions src/models/source_code_unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ use crate::{
models::capture_group_patterns::CGPattern,
models::rule_graph::{GLOBAL, PARENT},
utilities::tree_sitter_utilities::{
get_match_for_query, get_node_for_range, get_replace_range, get_tree_sitter_edit,
number_of_errors,
get_node_for_range, get_replace_range, get_tree_sitter_edit, number_of_errors,
},
};

Expand Down Expand Up @@ -299,13 +298,8 @@ impl SourceCodeUnit {
// let mut scope_node = self.root_node();
if let Some(query_str) = scope_query {
// Apply the scope query in the source code and get the appropriate node
let tree_sitter_scope_query = rules_store.query(query_str);
if let Some(p_match) = get_match_for_query(
&self.root_node(),
self.code(),
tree_sitter_scope_query,
true,
) {
let scope_pattern = rules_store.query(query_str);
if let Some(p_match) = scope_pattern.get_match(&self.root_node(), self.code(), true) {
return get_node_for_range(
self.root_node(),
p_match.range().start_byte,
Expand Down
22 changes: 0 additions & 22 deletions src/utilities/tree_sitter_utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,6 @@ use std::collections::HashMap;
use tree_sitter::{InputEdit, Node, Parser, Point, Query, QueryCapture, QueryCursor, Range};
use tree_sitter_traversal::{traverse, Order};

/// Applies the query upon the given node, and gets all the matches
/// # Arguments
/// * `node` - the root node to apply the query upon
/// * `source_code` - the corresponding source code string for the node.
/// * `query` - the query to be applied
/// * `recursive` - if `true` it matches the query to `self` and `self`'s sub-ASTs, else it matches the `query` only to `self`.
///
/// # Returns
/// A vector of `tuples` containing the range of the matches in the source code and the corresponding mapping for the tags (to code snippets).
/// By default it returns the range of the outermost node for each query match.
/// If `replace_node` is provided in the rule, it returns the range of the node corresponding to that tag.
pub(crate) fn get_match_for_query(
node: &Node, source_code: &str, query: &Query, recursive: bool,
) -> Option<Match> {
if let Some(m) =
get_all_matches_for_query(node, source_code.to_string(), query, recursive, None, None).first()
{
return Some(m.clone());
}
None
}

/// Applies the query upon the given `node`, and gets the first match
/// # Arguments
/// * `node` - the root node to apply the query upon
Expand Down

0 comments on commit 10a862b

Please sign in to comment.