Skip to content

Commit

Permalink
Merge pull request jruby#8424 from headius/sort_switch_jumps
Browse files Browse the repository at this point in the history
Sort the jump tables based on new values
  • Loading branch information
headius authored Nov 14, 2024
2 parents 6f9df83 + d216986 commit e3de28b
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 39 deletions.
34 changes: 27 additions & 7 deletions core/src/main/java/org/jruby/ir/instructions/BSwitchInstr.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public BSwitchInstr(int[] jumps, Operand[] jumpOperands, Operand operand, Label
super(Operation.B_SWITCH);

// We depend on the jump table being sorted, so ensure that's the case here
assert jumpsAreSorted(jumps);
assert jumpsAreSorted(jumps) : "jump table must be sorted";

// Switch cases must not have an empty "case" value (GH-6440)
assert operand != null : "Switch cases must not have an empty \"case\" value";
Expand Down Expand Up @@ -108,17 +108,37 @@ public void encode(IRWriterEncoder e) {
public static BSwitchInstr decode(IRReaderDecoder d) {
try {
Operand[] jumpOperands = d.decodeOperandArray();
Operand operand = d.decodeOperand();
Label rubyCase = d.decodeLabel();
Label[] targets = d.decodeLabelArray();
Label elseTarget = d.decodeLabel();
Class<?> expectedClass = Class.forName(d.decodeString());

int[] jumps = new int[jumpOperands.length];
for (int i = 0; i < jumps.length; i++) {
Operand operand = jumpOperands[i];
if (operand instanceof Symbol) {
jumps[i] = ((Symbol) operand).getSymbol().getId();
} else if (operand instanceof Fixnum) {
jumps[i] = (int) ((Fixnum) operand).getValue();
Operand jumpOperand = jumpOperands[i];
if (jumpOperand instanceof Symbol) {
jumps[i] = ((Symbol) jumpOperand).getSymbol().getId();
} else if (jumpOperand instanceof Fixnum) {
jumps[i] = (int) ((Fixnum) jumpOperand).getValue();
}
}

return new BSwitchInstr(jumps, jumpOperands, d.decodeOperand(), d.decodeLabel(), d.decodeLabelArray(), d.decodeLabel(), Class.forName(d.decodeString()));
int[] sortedJumps = jumps.clone();
Arrays.sort(sortedJumps);

Operand[] sortedJumpOperands = jumpOperands.clone();
Label[] sortedTargets = targets.clone();

for (int i = 0; i < jumps.length; i++) {
int oldJump = jumps[i];
int newIndex = Arrays.binarySearch(sortedJumps, oldJump);

sortedJumpOperands[newIndex] = jumpOperands[i];
sortedTargets[newIndex] = targets[i];
}

return new BSwitchInstr(sortedJumps, sortedJumpOperands, operand, rubyCase, sortedTargets, elseTarget, expectedClass);
} catch (Exception e) {
// should never happen unless encode was corrupted
Helpers.throwException(e);
Expand Down
127 changes: 95 additions & 32 deletions spec/compiler/general_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def self.name; "interpreter"; end
module PersistenceSpecUtils
include CompilerSpecUtils

def initialize(*x, **y)
super
@persist_runtime = org.jruby.Ruby.newInstance
end

attr_reader :persist_runtime

def run_in_method(src, filename = caller_locations[0].path, line = caller_locations[0].lineno)
run( "def __temp; #{src}; end; __temp", filename, line)
end
Expand All @@ -42,22 +49,34 @@ def self.name; "persistence"; end
private

def encode_decode_run(src, filename, line)
runtime = JRuby.runtime
manager = runtime.getIRManager()

method = JRuby.compile_ir(src, filename, false, line - 1)

top_self = runtime.top_self
# persist with separate runtime
jruby_module = persist_runtime.eval_scriptlet("require 'jruby'; JRuby")
persist_context = persist_runtime.current_context
persist_src = persist_runtime.new_string(src)
persist_filename = persist_runtime.new_string(src)
persist_line = persist_runtime.new_fixnum(line - 1)
method = org.jruby.ext.jruby.JRubyLibrary.compile_ir(
persist_context,
jruby_module,
[persist_src,
persist_filename,
persist_runtime.false,
persist_line].to_java(org.jruby.runtime.builtin.IRubyObject),
org.jruby.runtime.Block::NULL_BLOCK)

# encode and decode
baos = java.io.ByteArrayOutputStream.new
writer = org.jruby.ir.persistence.IRWriterStream.new(baos)
org.jruby.ir.persistence.IRWriter.persist(writer, method)

# interpret with test runtime
runtime = JRuby.runtime
manager = runtime.getIRManager()
top_self = runtime.top_self

reader = org.jruby.ir.persistence.IRReaderStream.new(manager, baos.to_byte_array, filename.to_java)
method = org.jruby.ir.persistence.IRReader.load(manager, reader)

# interpret
interpreter = org.jruby.ir.interpreter.Interpreter.new
interpreter.execute(runtime, method, top_self)
end
Expand Down Expand Up @@ -591,42 +610,86 @@ def foo

it "handles optimized homogeneous case/when" do
run('
case "a"
when "b"
fail
when "a"
1
else
fail
["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"].map do |x|
case x
when "a"
1
when "b"
2
when "c"
3
when "d"
4
when "e"
5
when "f"
6
when "g"
7
when "h"
8
when "i"
9
when "j"
10
else
fail
end
end
') do |result|
expect(result).to eq 1
expect(result).to eq [1,2,3,4,5,6,7,8,9,10]
end

run('
case :a
when :b
fail
when :a
1
else
fail
end
[:zxcvbnmzxcvbnm, :qwertyuiopqwertyuiop, :asdfghjklasdfghjkl, :a, :z].map do |x|
case x
when :zxcvbnmzxcvbnm
1
when :qwertyuiopqwertyuiop
2
when :asdfghjklasdfghjkl
3
when :a
4
when :z
5
else
fail
end
end
') do |result|
expect(result).to eq 1
expect(result).to eq [1,2,3,4,5]
end

run('
case 1
when 2
fail
when 1
1
else
fail
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10].map do |x|
case x
when 1
1
when 2
2
when 3
3
when 4
4
when 5
5
when 6
6
when 7
7
when 8
8
when 9
9
when 10
10
else
fail
end
end
') do |result|
expect(result).to eq 1
expect(result).to eq [1,2,3,4,5,6,7,8,9,10]
end
end

Expand Down

0 comments on commit e3de28b

Please sign in to comment.