diff --git a/yara-x/src/compiler/context.rs b/yara-x/src/compiler/context.rs index 14abe4f58..6aac69cf9 100644 --- a/yara-x/src/compiler/context.rs +++ b/yara-x/src/compiler/context.rs @@ -49,8 +49,8 @@ pub(in crate::compiler) struct Context<'a, 'src, 'sym> { /// Rule that is being compiled. pub current_rule: &'a RuleInfo, - // IR nodes for patterns defined in the rule being compiled. - pub current_rule_patterns: &'a mut Vec>, + /// IR nodes for patterns defined in the rule being compiled. + pub current_rule_patterns: &'a mut FxHashMap>, /// Warnings generated during the compilation. pub warnings: &'a mut Vec, @@ -82,16 +82,6 @@ pub(in crate::compiler) struct Context<'a, 'src, 'sym> { } impl<'a, 'src, 'sym> Context<'a, 'src, 'sym> { - /// Given an [`IdentId`] returns the identifier as `&str`. - /// - /// # Panics - /// - /// Panics if no identifier has the provided [`IdentId`]. - #[inline] - pub fn resolve_ident(&self, ident_id: IdentId) -> &str { - self.ident_pool.get(ident_id).unwrap() - } - /// Returns a [`RuleInfo`] given its [`RuleId`]. /// /// # Panics @@ -120,19 +110,15 @@ impl<'a, 'src, 'sym> Context<'a, 'src, 'sym> { .expect("identifier must be at least 1 character long") )); - for (ident_id, pattern_id) in &self.current_rule.patterns { + for (pattern_id, pattern) in self.current_rule_patterns.iter() { // Ignore the first character (`$`, `#`, `@` or `!`) while // comparing the identifiers. - if self.resolve_ident(*ident_id)[1..] == ident[1..] { + if pattern.identifier()[1..] == ident[1..] { return *pattern_id; } } - panic!( - "rule `{}` does not have pattern `{}` ", - self.resolve_ident(self.current_rule.ident_id), - ident - ); + panic!("pattern `{}` not found", ident); } /// Given a pattern identifier (e.g. `$a`, `#a`, `@a`) search for it in @@ -154,11 +140,12 @@ impl<'a, 'src, 'sym> Context<'a, 'src, 'sym> { .expect("identifier must be at least 1 character long") )); - for p in self.current_rule_patterns.iter_mut() { - if p.identifier()[1..] 
== ident[1..] { - return p; + for (_, pattern) in self.current_rule_patterns.iter_mut() { + if pattern.identifier()[1..] == ident[1..] { + return pattern; } } + panic!("pattern `{}` not found", ident); } diff --git a/yara-x/src/compiler/ir/ast2ir.rs b/yara-x/src/compiler/ir/ast2ir.rs index 78b20d03d..46d2983d3 100644 --- a/yara-x/src/compiler/ir/ast2ir.rs +++ b/yara-x/src/compiler/ir/ast2ir.rs @@ -617,8 +617,8 @@ fn of_expr_from_ast( // `x of them`, `x of ($a*, $b)` ast::OfItems::PatternSet(pattern_set) => { let pattern_ids = pattern_set_from_ast(ctx, pattern_set); - let num_items = pattern_ids.len(); - (OfItems::PatternSet(pattern_ids), num_items) + let num_patterns = pattern_ids.len(); + (OfItems::PatternSet(pattern_ids), num_patterns) } }; @@ -1030,28 +1030,36 @@ fn pattern_set_from_ast( ctx: &mut Context, pattern_set: &ast::PatternSet, ) -> Vec { - match pattern_set { + let pattern_ids = match pattern_set { // `x of them` ast::PatternSet::Them => ctx - .current_rule - .patterns + .current_rule_patterns .iter() - .map(|(_, pattern_id)| *pattern_id) + .map(|(pattern_id, _)| *pattern_id) .collect(), // `x of ($a*, $b)` - ast::PatternSet::Set(ref set_patterns) => { + ast::PatternSet::Set(ref set) => { let mut pattern_ids = Vec::new(); - for (ident_id, pattern_id) in &ctx.current_rule.patterns { - let ident = ctx.resolve_ident(*ident_id); + for (pattern_id, pattern) in ctx.current_rule_patterns.iter() { // Iterate over the patterns in the set (e.g: $foo, $foo*) and // check if some of them matches the identifier. - if set_patterns.iter().any(|p| p.matches(ident)) { + if set.iter().any(|p| p.matches(pattern.identifier())) { pattern_ids.push(*pattern_id); } } pattern_ids } + }; + + // Make all the patterns in the set non-anchorable. 
+ for pattern_id in pattern_ids.iter() { + ctx.current_rule_patterns + .get_mut(pattern_id) + .unwrap() + .make_non_anchorable(); } + + pattern_ids } fn func_call_from_ast( diff --git a/yara-x/src/compiler/mod.rs b/yara-x/src/compiler/mod.rs index 4a87f84e4..6d5931837 100644 --- a/yara-x/src/compiler/mod.rs +++ b/yara-x/src/compiler/mod.rs @@ -559,24 +559,29 @@ impl<'a> Compiler<'a> { self.check_for_existing_identifier(&rule.identifier)?; // Convert the patterns from AST to IR. - let mut patterns = + let patterns = patterns_from_ast(&self.report_builder, rule.patterns.as_ref())?; let num_patterns: usize = patterns.len(); - // Create array with pairs (IdentId, PatternId) that describe - // the patterns in a compiled rule. - let mut ident_and_pattern = Vec::with_capacity(num_patterns); + // Create array with pairs (IdentId, PatternId) + let mut ident_and_pattern_ids = Vec::with_capacity(num_patterns); + + // Create a map (PatternId, Pattern). + let mut patterns_map: FxHashMap = + FxHashMap::default(); for (pattern_id, pattern) in - iter::zip(self.next_pattern_id.successors(), &patterns) + iter::zip(self.next_pattern_id.successors(), patterns) { // Save pattern identifier (e.g: $a) in the pool of identifiers // or reuse the IdentId if the identifier has been used already. 
- ident_and_pattern.push(( + ident_and_pattern_ids.push(( self.ident_pool.get_or_intern(pattern.identifier()), pattern_id, )); + + patterns_map.insert(pattern_id, pattern); } let rule_id = RuleId(self.rules.len() as i32); @@ -586,7 +591,7 @@ impl<'a> Compiler<'a> { namespace_ident_id: self.current_namespace.ident_id, ident_id: self.ident_pool.get_or_intern(rule.identifier.name), ident_span: rule.identifier.span, - patterns: ident_and_pattern, + patterns: ident_and_pattern_ids, is_global: rule.flags.contains(RuleFlag::Global), is_private: rule.flags.contains(RuleFlag::Private), }); @@ -619,7 +624,7 @@ impl<'a> Compiler<'a> { report_builder: &self.report_builder, rules: &self.rules, current_rule: self.rules.last().unwrap(), - current_rule_patterns: &mut patterns, + current_rule_patterns: &mut patterns_map, wasm_symbols: &self.wasm_symbols, wasm_exports: &self.wasm_exports, warnings: &mut self.warnings, @@ -647,13 +652,12 @@ impl<'a> Compiler<'a> { drop(ctx); - let patterns_with_span = itertools::multizip(( - self.next_pattern_id.successors(), - patterns, + let patterns_with_span = iter::zip( + patterns_map, rule.patterns.iter().flatten().map(|p| p.span()), - )); + ); - for (pattern_id, pattern, span) in patterns_with_span { + for ((pattern_id, pattern), span) in patterns_with_span { self.current_pattern_id = pattern_id; match pattern { Pattern::Literal(pattern) => { diff --git a/yara-x/src/tests/mod.rs b/yara-x/src/tests/mod.rs index 574273afd..acade487c 100644 --- a/yara-x/src/tests/mod.rs +++ b/yara-x/src/tests/mod.rs @@ -1134,6 +1134,19 @@ fn match_at() { b"foobar" ); + rule_true!( + r#" + rule test { + strings: + $a = "foo" + $b = "bar" + condition: + 2 of ($a, $b) or $b at 0 + } + "#, + b"foobar" + ); + rule_false!( r#" rule test {