Skip to content

Commit

Permalink
Rollup merge of #107110 - strega-nil:mbtwc-wctmb, r=ChrisDenton
Browse files Browse the repository at this point in the history
[stdio][windows] Use MBTWC and WCTMB

`MultiByteToWideChar` and `WideCharToMultiByte` are extremely well optimized, and therefore should probably be used when we know we can (specifically in the Windows stdio stuff).

Fixes #107092
  • Loading branch information
matthiaskrgr authored Feb 27, 2023
2 parents cf04603 + 7f25580 commit 3fe4023
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 29 deletions.
1 change: 1 addition & 0 deletions library/std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@
all(target_vendor = "fortanix", target_env = "sgx"),
feature(slice_index_methods, coerce_unsized, sgx_platform)
)]
#![cfg_attr(windows, feature(round_char_boundary))]
//
// Language features:
#![feature(alloc_error_handler)]
Expand Down
32 changes: 30 additions & 2 deletions library/std/src/sys/windows/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

use crate::ffi::CStr;
use crate::mem;
use crate::os::raw::{c_char, c_int, c_long, c_longlong, c_uint, c_ulong, c_ushort};
use crate::os::raw::{c_char, c_long, c_longlong, c_uint, c_ulong, c_ushort};
use crate::os::windows::io::{BorrowedHandle, HandleOrInvalid, HandleOrNull};
use crate::ptr;
use core::ffi::NonZero_c_ulong;

use libc::{c_void, size_t, wchar_t};

pub use crate::os::raw::c_int;

#[path = "c/errors.rs"] // c.rs is included from two places so we need to specify this
mod errors;
pub use errors::*;
Expand Down Expand Up @@ -47,16 +49,19 @@ pub type ACCESS_MASK = DWORD;

pub type LPBOOL = *mut BOOL;
pub type LPBYTE = *mut BYTE;
pub type LPCCH = *const CHAR;
pub type LPCSTR = *const CHAR;
pub type LPCWCH = *const WCHAR;
pub type LPCWSTR = *const WCHAR;
pub type LPCVOID = *const c_void;
pub type LPDWORD = *mut DWORD;
pub type LPHANDLE = *mut HANDLE;
pub type LPOVERLAPPED = *mut OVERLAPPED;
pub type LPPROCESS_INFORMATION = *mut PROCESS_INFORMATION;
pub type LPSECURITY_ATTRIBUTES = *mut SECURITY_ATTRIBUTES;
pub type LPSTARTUPINFO = *mut STARTUPINFO;
pub type LPSTR = *mut CHAR;
pub type LPVOID = *mut c_void;
pub type LPCVOID = *const c_void;
pub type LPWCH = *mut WCHAR;
pub type LPWIN32_FIND_DATAW = *mut WIN32_FIND_DATAW;
pub type LPWSADATA = *mut WSADATA;
Expand Down Expand Up @@ -132,6 +137,10 @@ pub const MAX_PATH: usize = 260;

pub const FILE_TYPE_PIPE: u32 = 3;

pub const CP_UTF8: DWORD = 65001;
pub const MB_ERR_INVALID_CHARS: DWORD = 0x08;
pub const WC_ERR_INVALID_CHARS: DWORD = 0x80;

#[repr(C)]
#[derive(Copy)]
pub struct WIN32_FIND_DATAW {
Expand Down Expand Up @@ -1147,6 +1156,25 @@ extern "system" {
lpFilePart: *mut LPWSTR,
) -> DWORD;
pub fn GetFileAttributesW(lpFileName: LPCWSTR) -> DWORD;

pub fn MultiByteToWideChar(
CodePage: UINT,
dwFlags: DWORD,
lpMultiByteStr: LPCCH,
cbMultiByte: c_int,
lpWideCharStr: LPWSTR,
cchWideChar: c_int,
) -> c_int;
pub fn WideCharToMultiByte(
CodePage: UINT,
dwFlags: DWORD,
lpWideCharStr: LPCWCH,
cchWideChar: c_int,
lpMultiByteStr: LPSTR,
cbMultiByte: c_int,
lpDefaultChar: LPCCH,
lpUsedDefaultChar: LPBOOL,
) -> c_int;
}

#[link(name = "ws2_32")]
Expand Down
74 changes: 47 additions & 27 deletions library/std/src/sys/windows/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,27 @@ fn write(
}

fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usize> {
debug_assert!(!utf8.is_empty());

let mut utf16 = [MaybeUninit::<u16>::uninit(); MAX_BUFFER_SIZE / 2];
let mut len_utf16 = 0;
for (chr, dest) in utf8.encode_utf16().zip(utf16.iter_mut()) {
*dest = MaybeUninit::new(chr);
len_utf16 += 1;
}
// Safety: We've initialized `len_utf16` values.
let utf16: &[u16] = unsafe { MaybeUninit::slice_assume_init_ref(&utf16[..len_utf16]) };
let utf8 = &utf8[..utf8.floor_char_boundary(utf16.len())];

let utf16: &[u16] = unsafe {
// Note that this theoretically checks validity twice in the (most common) case
// where the underlying byte sequence is valid utf-8 (given the check in `write()`).
let result = c::MultiByteToWideChar(
c::CP_UTF8, // CodePage
c::MB_ERR_INVALID_CHARS, // dwFlags
utf8.as_ptr() as c::LPCCH, // lpMultiByteStr
utf8.len() as c::c_int, // cbMultiByte
utf16.as_mut_ptr() as c::LPWSTR, // lpWideCharStr
utf16.len() as c::c_int, // cchWideChar
);
assert!(result != 0, "Unexpected error in MultiByteToWideChar");

// Safety: MultiByteToWideChar initializes `result` values.
MaybeUninit::slice_assume_init_ref(&utf16[..result as usize])
};

let mut written = write_u16s(handle, &utf16)?;

Expand All @@ -189,8 +202,8 @@ fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usiz
// a missing surrogate can be produced (and also because of the UTF-8 validation above),
// write the missing surrogate out now.
// Buffering it would mean we have to lie about the number of bytes written.
let first_char_remaining = utf16[written];
if first_char_remaining >= 0xDCEE && first_char_remaining <= 0xDFFF {
let first_code_unit_remaining = utf16[written];
if first_code_unit_remaining >= 0xDCEE && first_code_unit_remaining <= 0xDFFF {
// low surrogate
// We just hope this works, and give up otherwise
let _ = write_u16s(handle, &utf16[written..written + 1]);
Expand All @@ -212,6 +225,7 @@ fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usiz
}

fn write_u16s(handle: c::HANDLE, data: &[u16]) -> io::Result<usize> {
debug_assert!(data.len() < u32::MAX as usize);
let mut written = 0;
cvt(unsafe {
c::WriteConsoleW(
Expand Down Expand Up @@ -365,26 +379,32 @@ fn read_u16s(handle: c::HANDLE, buf: &mut [MaybeUninit<u16>]) -> io::Result<usiz
Ok(amount as usize)
}

#[allow(unused)]
fn utf16_to_utf8(utf16: &[u16], utf8: &mut [u8]) -> io::Result<usize> {
let mut written = 0;
for chr in char::decode_utf16(utf16.iter().cloned()) {
match chr {
Ok(chr) => {
chr.encode_utf8(&mut utf8[written..]);
written += chr.len_utf8();
}
Err(_) => {
// We can't really do any better than forget all data and return an error.
return Err(io::const_io_error!(
io::ErrorKind::InvalidData,
"Windows stdin in console mode does not support non-UTF-16 input; \
encountered unpaired surrogate",
));
}
}
debug_assert!(utf16.len() <= c::c_int::MAX as usize);
debug_assert!(utf8.len() <= c::c_int::MAX as usize);

let result = unsafe {
c::WideCharToMultiByte(
c::CP_UTF8, // CodePage
c::WC_ERR_INVALID_CHARS, // dwFlags
utf16.as_ptr(), // lpWideCharStr
utf16.len() as c::c_int, // cchWideChar
utf8.as_mut_ptr() as c::LPSTR, // lpMultiByteStr
utf8.len() as c::c_int, // cbMultiByte
ptr::null(), // lpDefaultChar
ptr::null_mut(), // lpUsedDefaultChar
)
};
if result == 0 {
// We can't really do any better than forget all data and return an error.
Err(io::const_io_error!(
io::ErrorKind::InvalidData,
"Windows stdin in console mode does not support non-UTF-16 input; \
encountered unpaired surrogate",
))
} else {
Ok(result as usize)
}
Ok(written)
}

impl IncompleteUtf8 {
Expand Down

0 comments on commit 3fe4023

Please sign in to comment.