Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Fix chr impl taking funcs to string_slice and separating #3506

Open
wants to merge 1 commit into
base: nightly
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 26 additions & 35 deletions stdlib/src/collections/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ from utils import (
Formatter,
)
from utils.format import ToFormatter
from utils.string_slice import _utf8_byte_type, _StringSliceIter
from utils.string_slice import (
_utf8_byte_type,
_StringSliceIter,
_unicode_codepoint_utf8_byte_length,
_shift_unicode_to_utf8,
)

# ===----------------------------------------------------------------------=== #
# ord
Expand Down Expand Up @@ -107,50 +112,36 @@ fn ord(s: StringSlice) -> Int:


fn chr(c: Int) -> String:
"""Returns a string based on the given Unicode code point.

Returns the string representing a character whose code point is the integer
`c`. For example, `chr(97)` returns the string `"a"`. This is the inverse of
the `ord()` function.
"""Returns a String based on the given Unicode code point. This is the
inverse of the `ord()` function.

Args:
c: An integer that represents a code point.

Returns:
A string containing a single character based on the given code point.

Examples:
```mojo
print(chr(97)) # "a"
print(chr(8364)) # "€"
```
.
"""
# Unicode (represented as UInt32 BE) to UTF-8 conversion :
# 1: 00000000 00000000 00000000 0aaaaaaa -> 0aaaaaaa a
# 2: 00000000 00000000 00000aaa aabbbbbb -> 110aaaaa 10bbbbbb a >> 6 | 0b11000000, b | 0b10000000
# 3: 00000000 00000000 aaaabbbb bbcccccc -> 1110aaaa 10bbbbbb 10cccccc a >> 12 | 0b11100000, b >> 6 | 0b10000000, c | 0b10000000
# 4: 00000000 000aaabb bbbbcccc ccdddddd -> 11110aaa 10bbbbbb 10cccccc 10dddddd a >> 18 | 0b11110000, b >> 12 | 0b10000000, c >> 6 | 0b10000000, d | 0b10000000

if (c >> 7) == 0: # This is 1 byte ASCII char
return _chr_ascii(c)
if c < 0b1000_0000: # 1 byte ASCII char
return String(String._buffer_type(c, 0))

@always_inline
fn _utf8_len(val: Int) -> Int:
debug_assert(
0 <= val <= 0x10FFFF, "Value is not a valid Unicode code point"
)
alias sizes = SIMD[DType.int32, 4](
0, 0b1111_111, 0b1111_1111_111, 0b1111_1111_1111_1111
)
var values = SIMD[DType.int32, 4](val)
var mask = values > sizes
return int(mask.cast[DType.uint8]().reduce_add())

var num_bytes = _utf8_len(c)
var num_bytes = _unicode_codepoint_utf8_byte_length(c)
var p = UnsafePointer[UInt8].alloc(num_bytes + 1)
var shift = 6 * (num_bytes - 1)
var mask = UInt8(0xFF) >> (num_bytes + 1)
var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes)
p.store[width=1](((c >> shift) & mask) | num_bytes_marker)
for i in range(1, num_bytes):
shift -= 6
p.store[width=1](i, ((c >> shift) & 0b00111111) | 0b10000000)
p.store[width=1](num_bytes, 0)
return String(p.bitcast[UInt8](), num_bytes + 1)
_shift_unicode_to_utf8(p, c, num_bytes)
# TODO: decide whether to use replacement char (�) or raise ValueError
# if not _is_valid_utf8(p, num_bytes):
# debug_assert(False, "Invalid Unicode code point")
# p.free()
# return chr(0xFFFD)
p[num_bytes] = 0
return String(ptr=p, len=num_bytes + 1)


# ===----------------------------------------------------------------------=== #
Expand Down
33 changes: 33 additions & 0 deletions stdlib/src/utils/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,39 @@ alias StaticString = StringSlice[ImmutableStaticLifetime]
"""An immutable static string slice."""


fn _unicode_codepoint_utf8_byte_length(c: Int) -> Int:
    """Return the number of bytes needed to encode a code point in UTF-8.

    Args:
        c: The Unicode code point.

    Returns:
        The encoded length in bytes: 1 for `0..0x7F`, 2 for `0x80..0x7FF`,
        3 for `0x800..0xFFFF` and 4 for anything larger. Negative values
        yield 0. The value is not checked against the `0x10FFFF` upper limit.
    """
    # Smallest code point of each encoded length. Testing `<=` against these
    # lower bounds (instead of `<` against the upper bounds of the previous
    # class) makes code point 0 (NUL) count as one byte rather than zero,
    # while every other input maps to the same length as before.
    alias lower_bounds = SIMD[DType.int32, 4](0, 0x80, 0x800, 0x10000)
    return int((lower_bounds <= c).cast[DType.uint8]().reduce_add())


fn _shift_unicode_to_utf8(ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this argument ordering, and I don't know whether we have a convention for this yet.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no convention yet. I could make a case either way, but I have a weak preference for putting `num_bytes` as the second argument.

"""Shift unicode to utf8 representation.

### Unicode (represented as UInt32 BE) to UTF-8 conversion:
- 1: 00000000 00000000 00000000 0aaaaaaa -> 0aaaaaaa
- a
- 2: 00000000 00000000 00000aaa aabbbbbb -> 110aaaaa 10bbbbbb
- (a >> 6) | 0b11000000, b | 0b10000000
- 3: 00000000 00000000 aaaabbbb bbcccccc -> 1110aaaa 10bbbbbb 10cccccc
- (a >> 12) | 0b11100000, (b >> 6) | 0b10000000, c | 0b10000000
- 4: 00000000 000aaabb bbbbcccc ccdddddd -> 11110aaa 10bbbbbb 10cccccc
10dddddd
- (a >> 18) | 0b11110000, (b >> 12) | 0b10000000, (c >> 6) | 0b10000000,
d | 0b10000000
"""
if num_bytes == 1:
ptr[0] = UInt8(c)
return

var shift = 6 * (num_bytes - 1)
var mask = UInt8(0xFF) >> (num_bytes + 1)
var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes)
ptr[0] = ((c >> shift) & mask) | num_bytes_marker
for i in range(1, num_bytes):
shift -= 6
ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000


fn _utf8_byte_type(b: SIMD[DType.uint8, _], /) -> __type_of(b):
"""UTF-8 byte type.

Expand Down