From 05750a3bde76e0e205f64b03e0a4f9df9ee45c54 Mon Sep 17 00:00:00 2001
From: NAITOH Jun <naitoh@gmail.com>
Date: Sat, 17 Aug 2024 17:35:29 +0900
Subject: [PATCH] add #entity_expansion_limit=, #entity_expansion_text_limit=

## Why?
See: https://github.com/ruby/rexml/issues/192
---
 lib/rexml/document.rb             |  7 +++++-
 lib/rexml/parsers/baseparser.rb   | 14 ++++++++++--
 lib/rexml/parsers/pullparser.rb   |  8 +++++++
 lib/rexml/parsers/sax2parser.rb   |  8 +++++++
 lib/rexml/parsers/streamparser.rb |  8 +++++++
 test/test_document.rb             | 12 ++++------
 test/test_pullparser.rb           | 27 +++++++---------------
 test/test_sax.rb                  | 27 +++++++---------------
 test/test_stream.rb               | 37 ++++++++++++-------------------
 9 files changed, 76 insertions(+), 72 deletions(-)

diff --git a/lib/rexml/document.rb b/lib/rexml/document.rb
index b1caa020..00057bb2 100644
--- a/lib/rexml/document.rb
+++ b/lib/rexml/document.rb
@@ -91,6 +91,7 @@ class Document < Element
     #
     def initialize( source = nil, context = {} )
       @entity_expansion_count = 0
+      @entity_expansion_limit = Security.entity_expansion_limit
       super()
       @context = context
       return if source.nil?
@@ -434,11 +435,15 @@ def Document::entity_expansion_text_limit
 
     def record_entity_expansion
       @entity_expansion_count += 1
-      if @entity_expansion_count > Security.entity_expansion_limit
+      if @entity_expansion_count > @entity_expansion_limit
         raise "number of entity expansions exceeded, processing aborted."
       end
     end
 
+    def entity_expansion_limit=( limit )
+      @entity_expansion_limit = limit
+    end
+
     def document
       self
     end
diff --git a/lib/rexml/parsers/baseparser.rb b/lib/rexml/parsers/baseparser.rb
index d11c2766..11fd5bb1 100644
--- a/lib/rexml/parsers/baseparser.rb
+++ b/lib/rexml/parsers/baseparser.rb
@@ -164,6 +164,8 @@ def initialize( source )
         @listeners = []
         @prefixes = Set.new
         @entity_expansion_count = 0
+        @entity_expansion_limit = Security.entity_expansion_limit
+        @entity_expansion_text_limit = Security.entity_expansion_text_limit
       end
 
       def add_listener( listener )
@@ -585,7 +587,7 @@ def unnormalize( string, entities=nil, filter=nil )
               end
               re = Private::DEFAULT_ENTITIES_PATTERNS[entity_reference] || /&#{entity_reference};/
               rv.gsub!( re, entity_value )
-              if rv.bytesize > Security.entity_expansion_text_limit
+              if rv.bytesize > @entity_expansion_text_limit
                 raise "entity expansion has grown too large"
               end
             else
@@ -598,6 +600,14 @@ def unnormalize( string, entities=nil, filter=nil )
         rv
       end
 
+      def entity_expansion_limit=( limit )
+        @entity_expansion_limit = limit
+      end
+
+      def entity_expansion_text_limit=( limit )
+        @entity_expansion_text_limit = limit
+      end
+
       private
       def add_namespace(prefix, uri)
         @namespaces_restore_stack.last[prefix] = @namespaces[prefix]
@@ -627,7 +637,7 @@ def pop_namespaces_restore
 
       def record_entity_expansion(delta=1)
         @entity_expansion_count += delta
-        if @entity_expansion_count > Security.entity_expansion_limit
+        if @entity_expansion_count > @entity_expansion_limit
           raise "number of entity expansions exceeded, processing aborted."
         end
       end
diff --git a/lib/rexml/parsers/pullparser.rb b/lib/rexml/parsers/pullparser.rb
index 36b45953..a331eff5 100644
--- a/lib/rexml/parsers/pullparser.rb
+++ b/lib/rexml/parsers/pullparser.rb
@@ -51,6 +51,14 @@ def entity_expansion_count
         @parser.entity_expansion_count
       end
 
+      def entity_expansion_limit=( limit )
+        @parser.entity_expansion_limit = limit
+      end
+
+      def entity_expansion_text_limit=( limit )
+        @parser.entity_expansion_text_limit = limit
+      end
+
       def each
         while has_next?
           yield self.pull
diff --git a/lib/rexml/parsers/sax2parser.rb b/lib/rexml/parsers/sax2parser.rb
index cec9d2fc..5452d4b8 100644
--- a/lib/rexml/parsers/sax2parser.rb
+++ b/lib/rexml/parsers/sax2parser.rb
@@ -26,6 +26,14 @@ def entity_expansion_count
         @parser.entity_expansion_count
       end
 
+      def entity_expansion_limit=( limit )
+        @parser.entity_expansion_limit = limit
+      end
+
+      def entity_expansion_text_limit=( limit )
+        @parser.entity_expansion_text_limit = limit
+      end
+
       def add_listener( listener )
         @parser.add_listener( listener )
       end
diff --git a/lib/rexml/parsers/streamparser.rb b/lib/rexml/parsers/streamparser.rb
index 7781fe44..6c64d978 100644
--- a/lib/rexml/parsers/streamparser.rb
+++ b/lib/rexml/parsers/streamparser.rb
@@ -18,6 +18,14 @@ def entity_expansion_count
         @parser.entity_expansion_count
       end
 
+      def entity_expansion_limit=( limit )
+        @parser.entity_expansion_limit = limit
+      end
+
+      def entity_expansion_text_limit=( limit )
+        @parser.entity_expansion_text_limit = limit
+      end
+
       def parse
         # entity string
         while true
diff --git a/test/test_document.rb b/test/test_document.rb
index 25a8828f..bac5afb7 100644
--- a/test/test_document.rb
+++ b/test/test_document.rb
@@ -32,12 +32,10 @@ def test_new
 
     class EntityExpansionLimitTest < Test::Unit::TestCase
       def setup
-        @default_entity_expansion_limit = REXML::Security.entity_expansion_limit
         @default_entity_expansion_text_limit = REXML::Security.entity_expansion_text_limit
       end
 
       def teardown
-        REXML::Security.entity_expansion_limit = @default_entity_expansion_limit
         REXML::Security.entity_expansion_text_limit = @default_entity_expansion_text_limit
       end
 
@@ -64,9 +62,8 @@ def test_have_value
             doc.root.children.first.value
           end
 
-          REXML::Security.entity_expansion_limit = 100
-          assert_equal(100, REXML::Security.entity_expansion_limit)
           doc = REXML::Document.new(xml)
+          doc.entity_expansion_limit = 100
           assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do
             doc.root.children.first.value
           end
@@ -95,9 +92,8 @@ def test_empty_value
             doc.root.children.first.value
           end
 
-          REXML::Security.entity_expansion_limit = 100
-          assert_equal(100, REXML::Security.entity_expansion_limit)
           doc = REXML::Document.new(xml)
+          doc.entity_expansion_limit = 100
           assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do
             doc.root.children.first.value
           end
@@ -118,12 +114,12 @@ def test_with_default_entity
 </member>
 XML
 
-          REXML::Security.entity_expansion_limit = 4
           doc = REXML::Document.new(xml)
+          doc.entity_expansion_limit = 4
           assert_equal("\na\na a\n<\n", doc.root.children.first.value)
 
-          REXML::Security.entity_expansion_limit = 3
           doc = REXML::Document.new(xml)
+          doc.entity_expansion_limit = 3
           assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do
             doc.root.children.first.value
           end
diff --git a/test/test_pullparser.rb b/test/test_pullparser.rb
index 005a106a..bdf8be17 100644
--- a/test/test_pullparser.rb
+++ b/test/test_pullparser.rb
@@ -157,16 +157,6 @@ def test_peek
     end
 
     class EntityExpansionLimitTest < Test::Unit::TestCase
-      def setup
-        @default_entity_expansion_limit = REXML::Security.entity_expansion_limit
-        @default_entity_expansion_text_limit = REXML::Security.entity_expansion_text_limit
-      end
-
-      def teardown
-        REXML::Security.entity_expansion_limit = @default_entity_expansion_limit
-        REXML::Security.entity_expansion_text_limit = @default_entity_expansion_text_limit
-      end
-
       class GeneralEntityTest < self
         def test_have_value
           source = <<-XML
@@ -206,14 +196,13 @@ def test_empty_value
 </member>
           XML
 
-          REXML::Security.entity_expansion_limit = 100000
           parser = REXML::Parsers::PullParser.new(source)
+          parser.entity_expansion_limit = 100000
           while parser.has_next?
             parser.pull
           end
           assert_equal(11111, parser.entity_expansion_count)
 
-          REXML::Security.entity_expansion_limit = @default_entity_expansion_limit
           parser = REXML::Parsers::PullParser.new(source)
           assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do
             while parser.has_next?
@@ -221,7 +210,7 @@ def test_empty_value
             end
           end
           assert do
-            parser.entity_expansion_count > @default_entity_expansion_limit
+            parser.entity_expansion_count > REXML::Security.entity_expansion_limit
           end
         end
 
@@ -239,14 +228,14 @@ def test_with_default_entity
 </member>
           XML
 
-          REXML::Security.entity_expansion_limit = 4
           parser = REXML::Parsers::PullParser.new(source)
+          parser.entity_expansion_limit = 4
           while parser.has_next?
             parser.pull
           end
 
-          REXML::Security.entity_expansion_limit = 3
           parser = REXML::Parsers::PullParser.new(source)
+          parser.entity_expansion_limit = 3
           assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do
             while parser.has_next?
               parser.pull
@@ -255,7 +244,7 @@ def test_with_default_entity
         end
 
         def test_with_only_default_entities
-          member_value = "&lt;p&gt;#{'A' * @default_entity_expansion_text_limit}&lt;/p&gt;"
+          member_value = "&lt;p&gt;#{'A' * REXML::Security.entity_expansion_text_limit}&lt;/p&gt;"
           source = <<-XML
 <?xml version="1.0" encoding="UTF-8"?>
 <member>
@@ -276,11 +265,11 @@ def test_with_only_default_entities
             end
           end
 
-          expected_value = "<p>#{'A' * @default_entity_expansion_text_limit}</p>"
+          expected_value = "<p>#{'A' * REXML::Security.entity_expansion_text_limit}</p>"
           assert_equal(expected_value, events['member'].strip)
           assert_equal(0, parser.entity_expansion_count)
           assert do
-            events['member'].bytesize > @default_entity_expansion_text_limit
+            events['member'].bytesize > REXML::Security.entity_expansion_text_limit
           end
         end
 
@@ -296,8 +285,8 @@ def test_entity_expansion_text_limit
 <member>&a;</member>
           XML
 
-          REXML::Security.entity_expansion_text_limit = 90
           parser = REXML::Parsers::PullParser.new(source)
+          parser.entity_expansion_text_limit = 90
           events = {}
           element_name = ''
           while parser.has_next?
diff --git a/test/test_sax.rb b/test/test_sax.rb
index ae17e364..6aaeb618 100644
--- a/test/test_sax.rb
+++ b/test/test_sax.rb
@@ -100,16 +100,6 @@ def test_sax2
     end
 
     class EntityExpansionLimitTest < Test::Unit::TestCase
-      def setup
-        @default_entity_expansion_limit = REXML::Security.entity_expansion_limit
-        @default_entity_expansion_text_limit = REXML::Security.entity_expansion_text_limit
-      end
-
-      def teardown
-        REXML::Security.entity_expansion_limit = @default_entity_expansion_limit
-        REXML::Security.entity_expansion_text_limit = @default_entity_expansion_text_limit
-      end
-
       class GeneralEntityTest < self
         def test_have_value
           source = <<-XML
@@ -147,18 +137,17 @@ def test_empty_value
 </member>
           XML
 
-          REXML::Security.entity_expansion_limit = 100000
           sax = REXML::Parsers::SAX2Parser.new(source)
+          sax.entity_expansion_limit = 100000
           sax.parse
           assert_equal(11111, sax.entity_expansion_count)
 
-          REXML::Security.entity_expansion_limit = @default_entity_expansion_limit
           sax = REXML::Parsers::SAX2Parser.new(source)
           assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do
             sax.parse
           end
           assert do
-            sax.entity_expansion_count > @default_entity_expansion_limit
+            sax.entity_expansion_count > REXML::Security.entity_expansion_limit
           end
         end
 
@@ -176,19 +165,19 @@ def test_with_default_entity
 </member>
           XML
 
-          REXML::Security.entity_expansion_limit = 4
           sax = REXML::Parsers::SAX2Parser.new(source)
+          sax.entity_expansion_limit = 4
           sax.parse
 
-          REXML::Security.entity_expansion_limit = 3
           sax = REXML::Parsers::SAX2Parser.new(source)
+          sax.entity_expansion_limit = 3
           assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do
             sax.parse
           end
         end
 
         def test_with_only_default_entities
-          member_value = "&lt;p&gt;#{'A' * @default_entity_expansion_text_limit}&lt;/p&gt;"
+          member_value = "&lt;p&gt;#{'A' * REXML::Security.entity_expansion_text_limit}&lt;/p&gt;"
           source = <<-XML
 <?xml version="1.0" encoding="UTF-8"?>
 <member>
@@ -203,11 +192,11 @@ def test_with_only_default_entities
           end
           sax.parse
 
-          expected_value = "<p>#{'A' * @default_entity_expansion_text_limit}</p>"
+          expected_value = "<p>#{'A' * REXML::Security.entity_expansion_text_limit}</p>"
           assert_equal(expected_value, text_value.strip)
           assert_equal(0, sax.entity_expansion_count)
           assert do
-            text_value.bytesize > @default_entity_expansion_text_limit
+            text_value.bytesize > REXML::Security.entity_expansion_text_limit
           end
         end
 
@@ -223,8 +212,8 @@ def test_entity_expansion_text_limit
 <member>&a;</member>
           XML
 
-          REXML::Security.entity_expansion_text_limit = 90
           sax = REXML::Parsers::SAX2Parser.new(source)
+          sax.entity_expansion_text_limit = 90
           text_size = nil
           sax.listen(:characters, ["member"]) do |text|
             text_size = text.size
diff --git a/test/test_stream.rb b/test/test_stream.rb
index 782066c2..7917760a 100644
--- a/test/test_stream.rb
+++ b/test/test_stream.rb
@@ -126,16 +126,6 @@ def text(text)
   end
 
   class EntityExpansionLimitTest < Test::Unit::TestCase
-    def setup
-      @default_entity_expansion_limit = REXML::Security.entity_expansion_limit
-      @default_entity_expansion_text_limit = REXML::Security.entity_expansion_text_limit
-    end
-
-    def teardown
-      REXML::Security.entity_expansion_limit = @default_entity_expansion_limit
-      REXML::Security.entity_expansion_text_limit = @default_entity_expansion_text_limit
-    end
-
     def test_have_value
       source = <<-XML
 <?xml version="1.0" encoding="UTF-8"?>
@@ -172,18 +162,17 @@ def test_empty_value
       XML
 
       listener = MyListener.new
-      REXML::Security.entity_expansion_limit = 100000
       parser = REXML::Parsers::StreamParser.new( source, listener )
+      parser.entity_expansion_limit = 100000
       parser.parse
       assert_equal(11111, parser.entity_expansion_count)
 
-      REXML::Security.entity_expansion_limit = @default_entity_expansion_limit
       parser = REXML::Parsers::StreamParser.new( source, listener )
       assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do
         parser.parse
       end
       assert do
-        parser.entity_expansion_count > @default_entity_expansion_limit
+        parser.entity_expansion_count > REXML::Security.entity_expansion_limit
       end
     end
 
@@ -202,17 +191,19 @@ def test_with_default_entity
       XML
 
       listener = MyListener.new
-      REXML::Security.entity_expansion_limit = 4
-      REXML::Document.parse_stream(source, listener)
+      parser = REXML::Parsers::StreamParser.new( source, listener )
+      parser.entity_expansion_limit = 4
+      parser.parse
 
-      REXML::Security.entity_expansion_limit = 3
+      parser = REXML::Parsers::StreamParser.new( source, listener )
+      parser.entity_expansion_limit = 3
       assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do
-        REXML::Document.parse_stream(source, listener)
+        parser.parse
       end
     end
 
     def test_with_only_default_entities
-      member_value = "&lt;p&gt;#{'A' * @default_entity_expansion_text_limit}&lt;/p&gt;"
+      member_value = "&lt;p&gt;#{'A' * REXML::Security.entity_expansion_text_limit}&lt;/p&gt;"
       source = <<-XML
 <?xml version="1.0" encoding="UTF-8"?>
 <member>
@@ -231,11 +222,11 @@ def text(text)
       parser = REXML::Parsers::StreamParser.new( source, listener )
       parser.parse
 
-      expected_value = "<p>#{'A' * @default_entity_expansion_text_limit}</p>"
+      expected_value = "<p>#{'A' * REXML::Security.entity_expansion_text_limit}</p>"
       assert_equal(expected_value, listener.text_value.strip)
       assert_equal(0, parser.entity_expansion_count)
       assert do
-        listener.text_value.bytesize > @default_entity_expansion_text_limit
+        listener.text_value.bytesize > REXML::Security.entity_expansion_text_limit
       end
     end
 
@@ -259,9 +250,9 @@ def text(text)
         end
       end
       listener.text_value = ""
-      REXML::Security.entity_expansion_text_limit = 90
-      REXML::Document.parse_stream(source, listener)
-
+      parser = REXML::Parsers::StreamParser.new( source, listener )
+      parser.entity_expansion_text_limit = 90
+      parser.parse
       assert_equal(90, listener.text_value.size)
     end
   end