Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1569916: fix local testing default timestamp timezone issue #2114

Merged
merged 5 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#### Bug Fixes

- Fixed a bug where the truncate mode in `DataFrameWriter.save_as_table` incorrectly handled DataFrames containing only a subset of columns from the existing table.
- Fixed a bug where the `to_timestamp` function did not set the default timezone of the column data type.

### Snowpark pandas API Updates

Expand Down Expand Up @@ -81,6 +82,7 @@

- Fixed a bug where Window Functions LEAD and LAG do not handle option `ignore_nulls` properly.
- Fixed a bug where values were not populated into the result DataFrame during the insertion of table merge operation.
- Fixed a bug where the truncate mode in `DataFrameWriter.save_as_table` incorrectly handled DataFrames containing only a subset of columns from the existing table.
sfc-gh-aling marked this conversation as resolved.
Show resolved Hide resolved

#### Improvements

Expand Down
8 changes: 7 additions & 1 deletion src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from snowflake.snowpark._internal.analyzer.expression import FunctionExpression
from snowflake.snowpark.mock._options import numpy, pandas
from snowflake.snowpark.mock._snowflake_data_type import (
_TIMESTAMP_TYPE_MAPPING,
_TIMESTAMP_TYPE_TIMEZONE_MAPPING,
ColumnEmulator,
ColumnType,
TableEmulator,
Expand Down Expand Up @@ -943,7 +945,11 @@ def mock_to_timestamp(
try_cast: bool = False,
):
result = mock_to_timestamp_ntz(column, fmt, try_cast)
result.sf_type = ColumnType(TimestampType(), column.sf_type.nullable)

result.sf_type = ColumnType(
TimestampType(_TIMESTAMP_TYPE_TIMEZONE_MAPPING[_TIMESTAMP_TYPE_MAPPING]),
column.sf_type.nullable,
)
return result


Expand Down
19 changes: 19 additions & 0 deletions src/snowflake/snowpark/mock/_snowflake_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MapType,
NullType,
StringType,
TimestampTimeZone,
TimestampType,
TimeType,
VariantType,
Expand All @@ -32,6 +33,17 @@
PandasDataframeType = object if not installed_pandas else pd.DataFrame
PandasSeriesType = object if not installed_pandas else pd.Series

# Local-testing stand-in for Snowflake's TIMESTAMP_TYPE_MAPPING session parameter:
# https://docs.snowflake.com/en/sql-reference/parameters#label-timestamp-type-mapping
# The value is hard-coded for now; SNOW-1630258 tracks session parameter support
# in local testing.
_TIMESTAMP_TYPE_MAPPING = "TIMESTAMP_NTZ"


# Translates a TIMESTAMP_TYPE_MAPPING value into the matching TimestampTimeZone
# member, e.g. to build a TimestampType carrying an explicit timezone variant.
_TIMESTAMP_TYPE_TIMEZONE_MAPPING = {
"TIMESTAMP_NTZ": TimestampTimeZone.NTZ,
"TIMESTAMP_LTZ": TimestampTimeZone.LTZ,
"TIMESTAMP_TZ": TimestampTimeZone.TZ,
}


class Operator:
def op(self, *operands):
Expand Down Expand Up @@ -302,6 +314,13 @@ def coerce_t1_into_t2(t1: DataType, t2: DataType) -> Optional[DataType]:
elif isinstance(t1, (TimeType, TimestampType, MapType, ArrayType)):
if isinstance(t2, VariantType):
return t2
if isinstance(t1, TimestampType) and isinstance(t2, TimestampType):
if (
t1.tz is TimestampTimeZone.DEFAULT
and t2.tz is TimestampTimeZone.NTZ
and _TIMESTAMP_TYPE_MAPPING == "TIMESTAMP_NTZ"
):
return t2
return None


Expand Down
13 changes: 4 additions & 9 deletions tests/integ/scala/test_dataframe_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,7 +1654,7 @@ def test_flatten_in_session(session):
)


def test_createDataFrame_with_given_schema(session, local_testing_mode):
def test_createDataFrame_with_given_schema(session):
schema = StructType(
[
StructField("string", StringType(84)),
Expand Down Expand Up @@ -1728,12 +1728,7 @@ def test_createDataFrame_with_given_schema(session, local_testing_mode):
StructField("number", DecimalType(10, 3)),
StructField("boolean", BooleanType()),
StructField("binary", BinaryType()),
StructField(
"timestamp",
TimestampType(TimestampTimeZone.NTZ)
if not local_testing_mode
else TimestampType(),
), # depends on TIMESTAMP_TYPE_MAPPING
StructField("timestamp", TimestampType(TimestampTimeZone.NTZ)),
StructField("timestamp_ntz", TimestampType(TimestampTimeZone.NTZ)),
StructField("timestamp_ltz", TimestampType(TimestampTimeZone.LTZ)),
StructField("timestamp_tz", TimestampType(TimestampTimeZone.TZ)),
Expand All @@ -1759,7 +1754,7 @@ def test_createDataFrame_with_given_schema_time(session):
assert df.collect() == data


def test_createDataFrame_with_given_schema_timestamp(session, local_testing_mode):
def test_createDataFrame_with_given_schema_timestamp(session):
schema = StructType(
[
StructField("timestamp", TimestampType()),
Expand All @@ -1780,7 +1775,7 @@ def test_createDataFrame_with_given_schema_timestamp(session, local_testing_mode

assert (
schema_str
== f"StructType([StructField('TIMESTAMP', TimestampType({'' if local_testing_mode else 'tz=ntz'}), nullable=True), "
== "StructType([StructField('TIMESTAMP', TimestampType(tz=ntz), nullable=True), "
"StructField('TIMESTAMP_NTZ', TimestampType(tz=ntz), nullable=True), "
"StructField('TIMESTAMP_LTZ', TimestampType(tz=ltz), nullable=True), "
"StructField('TIMESTAMP_TZ', TimestampType(tz=tz), nullable=True)])"
Expand Down
29 changes: 29 additions & 0 deletions tests/integ/scala/test_function_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3839,6 +3839,35 @@ def test_convert_timezone(session, local_testing_mode):
],
)

df = TestData.datetime_primitives1(session).select("timestamp", "timestamp_ntz")

Utils.check_answer(
df.select(
*[
convert_timezone(lit("UTC"), col, lit("Asia/Shanghai"))
for col in df.columns
]
),
[
Row(
datetime(2024, 2, 1, 4, 0),
datetime(2017, 2, 24, 4, 0, 0, 456000),
)
],
)

df = TestData.datetime_primitives1(session).select(
"timestamp_ltz", "timestamp_tz"
)
with pytest.raises(SnowparkSQLException):
# convert_timezone function does not accept non-TimestampTimeZone.NTZ datetime
df.select(
*[
convert_timezone(lit("UTC"), col, lit("Asia/Shanghai"))
for col in df.columns
]
).collect()

LocalTimezone.set_local_timezone()


Expand Down
Loading