Skip to content

Commit

Permalink
make generate script output formatted rust code
Browse files (browse the repository at this point in the history)
  • Loading branch information
seanmonstar committed Aug 14, 2023
1 parent 6dbba19 commit 50ec57d
Showing 1 changed file with 40 additions and 32 deletions.
72 changes: 40 additions & 32 deletions scripts/mapgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ def variant(num):
return 'Fold::Two'
elif num == 3:
return 'Fold::Three'

txt = open('./scripts/CaseFolding.txt')

def replacement(chars):
chars_len = len(chars)
inside = ', '.join(["'\\u{%04x}'" % c for c in chars])
return '%s(%s,)' % (variant(chars_len), inside)
return '%s(%s)' % (variant(chars_len), inside)

def apply_constant_offset(offset_from, offset_to):
if offset_to < offset_from:
Expand All @@ -35,14 +35,14 @@ def __init__(self, map_from, map_tos):
self.end = map_from
self.map_tos = map_tos
self.every_other = None

def limit_to_range(self, min_relevant, max_relevant):

if self.end < min_relevant: return None
if self.start > max_relevant: return None

if self.start >= min_relevant and self.end <= max_relevant: return self

ret = Run(self.start, [m for m in self.map_tos])
ret.end = self.end
ret.every_other = self.every_other
Expand All @@ -54,14 +54,14 @@ def limit_to_range(self, min_relevant, max_relevant):
ret.map_tos[0] += diff
if ret.end > max_relevant:
ret.end = max_relevant

return ret

def expand_into(self, map_from, map_tos):
if len(self.map_tos) != 1 or len(map_tos) != 1:
# Do not attempt to combine if we are not mapping to one character. Those do not follow a simple pattern.
return False

if self.every_other!=True and self.end + 1 == map_from and map_tos[0] == self.map_tos[0] + (map_from - self.start):
self.end += 1
self.every_other = False
Expand All @@ -70,7 +70,7 @@ def expand_into(self, map_from, map_tos):
self.end += 2
self.every_other = True
return True

return False

# When dumping ranges, we avoid using range literals to maintain compatibility with old rustcs
Expand All @@ -92,13 +92,19 @@ def remove_useless_comparison(case_line):
else:
rs.write(" %s => return %s,\n" % (format_range_edge(self.start), replacement(self.map_tos)))
elif self.every_other != True:
rs.write(remove_useless_comparison(" x @ _ if %s <= x && x <= %s => from.%s,\n" % (format_range_edge(self.start), format_range_edge(self.end), apply_constant_offset(self.start, self.map_tos[0]))),)
rs.write(remove_useless_comparison(" x @ _ if %s <= x && x <= %s => from.%s,\n" % (format_range_edge(self.start), format_range_edge(self.end), apply_constant_offset(self.start, self.map_tos[0]))),)
elif self.map_tos[0] - self.start == 1 and self.start%2==0:
rs.write(remove_useless_comparison(" x @ _ if %s <= x && x <= %s => (from | 1),\n" % (format_range_edge(self.start), format_range_edge(self.end))))
rs.write(remove_useless_comparison(" x @ _ if %s <= x && x <= %s => from | 1,\n" % (format_range_edge(self.start), format_range_edge(self.end))))
elif self.map_tos[0] - self.start == 1 and self.start%2==1:
rs.write(remove_useless_comparison(" x @ _ if %s <= x && x <= %s => ((from+1) & !1),\n" % (format_range_edge(self.start), format_range_edge(self.end))))
rs.write(remove_useless_comparison(" x @ _ if %s <= x && x <= %s => (from + 1) & !1,\n" % (format_range_edge(self.start), format_range_edge(self.end))))
else:
rs.write(remove_useless_comparison(" x @ _ if %s <= x && x <= %s => if (from & 1) == %s { from.%s } else { from },\n" % (format_range_edge(self.start), format_range_edge(self.end), self.start%2, apply_constant_offset(self.start, self.map_tos[0]))))
rs.write(" x @ _ if %s <= x && x <= %s => {\n" % (format_range_edge(self.start), format_range_edge(self.end)))
rs.write(" if (from & 1) == %s {\n" % (self.start % 2))
rs.write(" from.%s\n" % apply_constant_offset(self.start, self.map_tos[0]))
rs.write(" } else {\n")
rs.write(" from\n")
rs.write(" }\n")
rs.write(" }\n")

runs = []
singlet_runs = [] # for test generation
Expand All @@ -109,7 +115,7 @@ def remove_useless_comparison(case_line):
if len(parts) > 2 and parts[1] in 'CF':
map_from = int(parts[0], 16)
map_tos = [int(char, 16) for char in parts[2].split(' ')]

if run_in_progress and run_in_progress.expand_into(map_from, map_tos):
pass
else:
Expand All @@ -119,7 +125,7 @@ def remove_useless_comparison(case_line):
runs.append(run_in_progress)

high_runs = [r for r in runs if r.end > 0x2CFF]

small_run_chunks = [] # Each element of this corresponds to a high byte being mapped from
for high_byte in range(0, 0x2D):
minimum_relevant = (high_byte<<8)
Expand Down Expand Up @@ -152,27 +158,26 @@ def remove_useless_comparison(case_line):
rs.write(' let low_byte = (from & 0xff) as u8;\n');
rs.write(' let single_char: u16 = match high_byte {\n')
for (high_byte, runs) in enumerate(small_run_chunks):
rs.write(" 0x%02x => {\n" % high_byte);
rs.write(" 0x%02x => " % high_byte);
if len(runs)==0:
rs.write(' from\n')
rs.write('from,\n')
else:
rs.write(" match low_byte {\n")
rs.write("match low_byte {\n")
for r in runs:
rs.write(' ')
rs.write(' ')
r.dump(match_on_low_byte = True)
rs.write(" _ => from\n")
rs.write(" }\n");
rs.write(" }\n")
rs.write(' _ => from \n')
rs.write(" _ => from,\n")
rs.write(" },\n");
rs.write(' _ => from,\n')
rs.write(' };\n');
rs.write(' Fold::One( char::from_u32(single_char as u32).unwrap_or(orig) )\n')
rs.write(' Fold::One(char::from_u32(single_char as u32).unwrap_or(orig))\n')
rs.write(' } else {\n');
rs.write(' let single_char: u32 = match from {\n')
for r in high_runs:
r.dump()
rs.write(' _ => from\n')
rs.write(' _ => from,\n')
rs.write(' };\n')
rs.write(' Fold::One( char::from_u32(single_char).unwrap_or(orig) )\n')
rs.write(' Fold::One(char::from_u32(single_char).unwrap_or(orig))\n')
rs.write(' }\n')
rs.write('}\n')

Expand All @@ -186,17 +191,20 @@ def remove_useless_comparison(case_line):
rs.write(' let single_char = match orig as u32 {\n');
for r in singlet_runs:
r.dump()
rs.write(' _ => orig as u32\n')
rs.write(' _ => orig as u32,\n')
rs.write(' };\n')
rs.write(' Fold::One( char::from_u32(single_char).unwrap() )\n')
rs.write(' }\n')
rs.write(' \n')
rs.write(' Fold::One(char::from_u32(single_char).unwrap())\n')
rs.write(' }\n\n')
rs.write(' for c_index in 0..%d {\n' % test_max)
rs.write(' if let Some(c) = char::from_u32(c_index) {\n');
rs.write(' let reference: Vec<char> = lookup_naive(c).collect();\n')
rs.write(' let actual: Vec<char> = lookup(c).collect();\n')
rs.write(' if actual != reference {\n')
rs.write(' assert!(false, "case-folding {:?} (#0x{:04x}) failed: Expected {:?}, got {:?}", c, c_index, reference, actual);\n')
rs.write(' assert!(\n')
rs.write(' false,\n')
rs.write(' "case-folding {:?} (#0x{:04x}) failed: Expected {:?}, got {:?}",\n')
rs.write(' c, c_index, reference, actual\n')
rs.write(' );\n')
rs.write(' }\n')
rs.write(' }\n')
rs.write(' }\n')
Expand Down

0 comments on commit 50ec57d

Please sign in to comment.