diff --git a/stdlib/src/collections/string.mojo b/stdlib/src/collections/string.mojo index fbedee4e2b..9c2fb1f4cf 100644 --- a/stdlib/src/collections/string.mojo +++ b/stdlib/src/collections/string.mojo @@ -34,7 +34,12 @@ from utils import ( Formatter, ) from utils.format import ToFormatter -from utils.string_slice import _utf8_byte_type, _StringSliceIter +from utils.string_slice import ( + _utf8_byte_type, + _StringSliceIter, + _unicode_codepoint_utf8_byte_length, + _shift_unicode_to_utf8, +) # ===----------------------------------------------------------------------=== # # ord @@ -107,50 +112,36 @@ fn ord(s: StringSlice) -> Int: fn chr(c: Int) -> String: - """Returns a string based on the given Unicode code point. - - Returns the string representing a character whose code point is the integer - `c`. For example, `chr(97)` returns the string `"a"`. This is the inverse of - the `ord()` function. + """Returns a String based on the given Unicode code point. This is the + inverse of the `ord()` function. Args: c: An integer that represents a code point. Returns: A string containing a single character based on the given code point. + + Examples: + ```mojo + print(chr(97)) # "a" + print(chr(8364)) # "€" + ``` + . """ - # Unicode (represented as UInt32 BE) to UTF-8 conversion : - # 1: 00000000 00000000 00000000 0aaaaaaa -> 0aaaaaaa a - # 2: 00000000 00000000 00000aaa aabbbbbb -> 110aaaaa 10bbbbbb a >> 6 | 0b11000000, b | 0b10000000 - # 3: 00000000 00000000 aaaabbbb bbcccccc -> 1110aaaa 10bbbbbb 10cccccc a >> 12 | 0b11100000, b >> 6 | 0b10000000, c | 0b10000000 - # 4: 00000000 000aaabb bbbbcccc ccdddddd -> 11110aaa 10bbbbbb 10cccccc 10dddddd a >> 18 | 0b11110000, b >> 12 | 0b10000000, c >> 6 | 0b10000000, d | 0b10000000 - if (c >> 7) == 0: # This is 1 byte ASCII char - return _chr_ascii(c) + if c < 0b1000_0000: # 1 byte ASCII char + return String(String._buffer_type(c, 0)) - @always_inline - fn _utf8_len(val: Int) -> Int: - debug_assert( - 0 <= val <= 0x10FFFF, "Value is not a valid Unicode code point" - ) - alias sizes = SIMD[DType.int32, 4]( - 0, 0b1111_111, 0b1111_1111_111, 0b1111_1111_1111_1111 - ) - var values = SIMD[DType.int32, 4](val) - var mask = values > sizes - return int(mask.cast[DType.uint8]().reduce_add()) - - var num_bytes = _utf8_len(c) + var num_bytes = _unicode_codepoint_utf8_byte_length(c) var p = UnsafePointer[UInt8].alloc(num_bytes + 1) - var shift = 6 * (num_bytes - 1) - var mask = UInt8(0xFF) >> (num_bytes + 1) - var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes) - p.store[width=1](((c >> shift) & mask) | num_bytes_marker) - for i in range(1, num_bytes): - shift -= 6 - p.store[width=1](i, ((c >> shift) & 0b00111111) | 0b10000000) - p.store[width=1](num_bytes, 0) - return String(p.bitcast[UInt8](), num_bytes + 1) + _shift_unicode_to_utf8(p, c, num_bytes) + # TODO: decide whether to use replacement char (�) or raise ValueError + # if not _is_valid_utf8(p, num_bytes): + # debug_assert(False, "Invalid Unicode code point") + # p.free() + # return chr(0xFFFD) + p[num_bytes] = 0 + return String(ptr=p, len=num_bytes + 1) # ===----------------------------------------------------------------------=== # diff --git a/stdlib/src/utils/string_slice.mojo b/stdlib/src/utils/string_slice.mojo index 031e9b7d75..d2a788ea8e 100644 --- a/stdlib/src/utils/string_slice.mojo +++ b/stdlib/src/utils/string_slice.mojo @@ -31,6 +31,39 @@ alias StaticString = StringSlice[ImmutableStaticLifetime] """An immutable static string slice.""" +fn _unicode_codepoint_utf8_byte_length(c: Int) -> Int: + alias sizes = SIMD[DType.int32, 4](0, 0b0111_1111, 0b0111_1111_1111, 0xFFFF) + return int((sizes < c).cast[DType.uint8]().reduce_add()) + + +fn _shift_unicode_to_utf8(ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int): + """Shift unicode to utf8 representation. + + ### Unicode (represented as UInt32 BE) to UTF-8 conversion: + - 1: 00000000 00000000 00000000 0aaaaaaa -> 0aaaaaaa + - a + - 2: 00000000 00000000 00000aaa aabbbbbb -> 110aaaaa 10bbbbbb + - (a >> 6) | 0b11000000, b | 0b10000000 + - 3: 00000000 00000000 aaaabbbb bbcccccc -> 1110aaaa 10bbbbbb 10cccccc + - (a >> 12) | 0b11100000, (b >> 6) | 0b10000000, c | 0b10000000 + - 4: 00000000 000aaabb bbbbcccc ccdddddd -> 11110aaa 10bbbbbb 10cccccc + 10dddddd + - (a >> 18) | 0b11110000, (b >> 12) | 0b10000000, (c >> 6) | 0b10000000, + d | 0b10000000 + """ + if num_bytes == 1: + ptr[0] = UInt8(c) + return + + var shift = 6 * (num_bytes - 1) + var mask = UInt8(0xFF) >> (num_bytes + 1) + var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes) + ptr[0] = ((c >> shift) & mask) | num_bytes_marker + for i in range(1, num_bytes): + shift -= 6 + ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000 + + fn _utf8_byte_type(b: SIMD[DType.uint8, _], /) -> __type_of(b): """UTF-8 byte type.