From de746455d7c52abe94bb219fdb9dac71976e74cf Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Tue, 30 Jan 2024 14:51:35 +0800 Subject: [PATCH] Refactor --- src/main/cpp/src/datetime_parser.cu | 79 +++++++++++++++-------------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/src/main/cpp/src/datetime_parser.cu b/src/main/cpp/src/datetime_parser.cu index d02139cce4..505f9821ad 100644 --- a/src/main/cpp/src/datetime_parser.cu +++ b/src/main/cpp/src/datetime_parser.cu @@ -142,12 +142,14 @@ struct parse_timestamp_string_fn { return thrust::make_tuple(cudf::timestamp_us{cudf::duration_us{0}}, ParseResult::INVALID); } - auto const d_str = d_strings.element(idx); - - timestamp_components ts_comp{}; - char const* tz_lit_ptr = nullptr; - cudf::size_type tz_lit_len = 0; - switch (parse_string_to_timestamp_us(&ts_comp, &tz_lit_ptr, &tz_lit_len, d_str)) { + auto const d_str = d_strings.element(idx); + auto parse_ret_tuple = parse_string_to_timestamp_us(d_str); + auto ts_comp = thrust::get<0>(parse_ret_tuple); + auto tz_lit_ptr = thrust::get<1>(parse_ret_tuple); + auto tz_lit_len = thrust::get<2>(parse_ret_tuple); + auto result = thrust::get<3>(parse_ret_tuple); + + switch (result) { case ParseResult::INVALID: return thrust::make_tuple(cudf::timestamp_us{cudf::duration_us{0}}, ParseResult::INVALID); case ParseResult::UNSUPPORTED: @@ -391,12 +393,15 @@ struct parse_timestamp_string_fn { * Parse a string with time zone to a timestamp. * The bool in the returned tuple is false if the parse failed. */ - __device__ inline ParseResult parse_string_to_timestamp_us( - timestamp_components* ts_comp, - char const** parsed_tz_ptr, - cudf::size_type* parsed_tz_length, - cudf::string_view const& timestamp_str) const + __device__ inline thrust::tuple + parse_string_to_timestamp_us(cudf::string_view const& timestamp_str) const { + timestamp_components ts_comp{}; + char const* parsed_tz_ptr = nullptr; + cudf::size_type parsed_tz_length = -1; + auto invalid_ret = + thrust::make_tuple(ts_comp, parsed_tz_ptr, parsed_tz_length, ParseResult::INVALID); + const char* curr_ptr = timestamp_str.data(); const char* end_ptr = curr_ptr + timestamp_str.size_bytes(); @@ -409,7 +414,7 @@ struct parse_timestamp_string_fn { --end_ptr; } - if (curr_ptr == end_ptr) { return ParseResult::INVALID; } + if (curr_ptr == end_ptr) { return invalid_ret; } const char* const bytes = curr_ptr; const cudf::size_type bytes_length = end_ptr - curr_ptr; @@ -441,72 +446,72 @@ struct parse_timestamp_string_fn { i += 3; } else if (i < 2) { if (b == '-') { - if (!is_valid_digits(i, current_segment_digits)) { return ParseResult::INVALID; } + if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; } segments[i] = current_segment_value; current_segment_value = 0; current_segment_digits = 0; i += 1; } else if (0 == i && ':' == b && !year_sign.has_value()) { - if (!is_valid_digits(3, current_segment_digits)) { return ParseResult::INVALID; } + if (!is_valid_digits(3, current_segment_digits)) { return invalid_ret; } segments[3] = current_segment_value; current_segment_value = 0; current_segment_digits = 0; i = 4; } else { - return ParseResult::INVALID; + return invalid_ret; } } else if (2 == i) { if (' ' == b || 'T' == b) { - if (!is_valid_digits(i, current_segment_digits)) { return ParseResult::INVALID; } + if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; } segments[i] = current_segment_value; current_segment_value = 0; current_segment_digits = 0; i += 1; } else { - return ParseResult::INVALID; + return invalid_ret; } } else if (3 == i || 4 == i) { if (':' == b) { - if (!is_valid_digits(i, current_segment_digits)) { return ParseResult::INVALID; } + if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; } segments[i] = current_segment_value; current_segment_value = 0; current_segment_digits = 0; i += 1; } else { - return ParseResult::INVALID; + return invalid_ret; } } else if (5 == i || 6 == i) { if ('.' == b && 5 == i) { - if (!is_valid_digits(i, current_segment_digits)) { return ParseResult::INVALID; } + if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; } segments[i] = current_segment_value; current_segment_value = 0; current_segment_digits = 0; i += 1; } else { if (!is_valid_digits(i, current_segment_digits) || !allow_tz_in_date_str) { - return ParseResult::INVALID; + return invalid_ret; } segments[i] = current_segment_value; current_segment_value = 0; current_segment_digits = 0; i += 1; - *parsed_tz_ptr = bytes + j; + parsed_tz_ptr = bytes + j; // strip the whitespace between timestamp and timezone - while (*parsed_tz_ptr < end_ptr && is_whitespace(**parsed_tz_ptr)) - ++(*parsed_tz_ptr); - *parsed_tz_length = end_ptr - *parsed_tz_ptr; + while (parsed_tz_ptr < end_ptr && is_whitespace(*parsed_tz_ptr)) + ++parsed_tz_ptr; + parsed_tz_length = end_ptr - parsed_tz_ptr; break; } if (i == 6 && '.' != b) { i += 1; } } else { if (i < segments_len && (':' == b || ' ' == b)) { - if (!is_valid_digits(i, current_segment_digits)) { return ParseResult::INVALID; } + if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; } segments[i] = current_segment_value; current_segment_value = 0; current_segment_digits = 0; i += 1; } else { - return ParseResult::INVALID; + return invalid_ret; } } } else { @@ -521,7 +526,7 @@ struct parse_timestamp_string_fn { j += 1; } - if (!is_valid_digits(i, current_segment_digits)) { return ParseResult::INVALID; } + if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; } segments[i] = current_segment_value; while (digits_milli < 6) { @@ -535,15 +540,15 @@ struct parse_timestamp_string_fn { // copy segments to equivalent kernel timestamp_components // Note: In order to keep above code is equivalent to Spark implementation, // did not use `timestamp_components` directly to save values. - ts_comp->year = segments[0]; - ts_comp->month = static_cast(segments[1]); - ts_comp->day = static_cast(segments[2]); - ts_comp->hour = static_cast(segments[3]); - ts_comp->minute = static_cast(segments[4]); - ts_comp->second = static_cast(segments[5]); - ts_comp->microseconds = segments[6]; - - return ParseResult::OK; + ts_comp.year = segments[0]; + ts_comp.month = static_cast(segments[1]); + ts_comp.day = static_cast(segments[2]); + ts_comp.hour = static_cast(segments[3]); + ts_comp.minute = static_cast(segments[4]); + ts_comp.second = static_cast(segments[5]); + ts_comp.microseconds = segments[6]; + + return thrust::make_tuple(ts_comp, parsed_tz_ptr, parsed_tz_length, ParseResult::OK); } };