Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Implement msgspec encoding #2541

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
10 changes: 5 additions & 5 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
def mypy(session: Session) -> None:
"""Check types with mypy."""
args = session.posargs or ["singer_sdk"]
session.install(".[faker,jwt,parquet,s3,testing]")
session.install(".[faker,jwt,msgspec,parquet,s3,testing]")
session.install(*typing_dependencies)
session.run("mypy", *args)
if not session.posargs:
Expand All @@ -63,7 +63,7 @@ def mypy(session: Session) -> None:
@session(python=python_versions)
def tests(session: Session) -> None:
"""Execute pytest tests and compute coverage."""
session.install(".[faker,jwt,parquet,s3]")
session.install(".[faker,jwt,msgspec,parquet,s3]")
session.install(*test_dependencies)

sqlalchemy_version = os.environ.get("SQLALCHEMY_VERSION")
Expand Down Expand Up @@ -96,7 +96,7 @@ def tests(session: Session) -> None:
@session(python=main_python_version)
def benches(session: Session) -> None:
"""Run benchmarks."""
session.install(".[jwt,s3]")
session.install(".[jwt,msgspec,s3]")
session.install(*test_dependencies)
sqlalchemy_version = os.environ.get("SQLALCHEMY_VERSION")
if sqlalchemy_version:
Expand All @@ -116,7 +116,7 @@ def benches(session: Session) -> None:
@session(name="deps", python=python_versions)
def dependencies(session: Session) -> None:
"""Check issues with dependencies."""
session.install(".[s3,testing]")
session.install(".[msgspec,s3,testing]")
session.install("deptry")
session.run("deptry", "singer_sdk", *session.posargs)

Expand All @@ -126,7 +126,7 @@ def update_snapshots(session: Session) -> None:
"""Update pytest snapshots."""
args = session.posargs or ["-m", "snapshot"]

session.install(".[faker,jwt,parquet]")
session.install(".[faker,jwt,msgspec,parquet]")
session.install(*test_dependencies)
session.run("pytest", "--snapshot-update", *args)

Expand Down
55 changes: 54 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ inflection = ">=0.5.1"
joblib = ">=1.3.0"
jsonpath-ng = ">=1.5.3"
jsonschema = ">=4.16.0"
msgspec = { version = ">=0.18.6", optional = true }
packaging = ">=23.1"
python-dotenv = ">=0.20"
PyYAML = ">=6.0"
Expand Down Expand Up @@ -111,6 +112,7 @@ docs = [
"sphinx-notfound-page",
"sphinx-reredirects",
]
msgspec = ["msgspec"]
s3 = ["fs-s3fs"]
testing = [
"pytest",
Expand Down
3 changes: 2 additions & 1 deletion samples/sample_tap_countries/countries_tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
CountriesStream,
)
from singer_sdk import Stream, Tap
from singer_sdk._singerlib.encoding._msgspec import MsgSpecWriter # noqa: PLC2701
from singer_sdk.typing import PropertiesList


class SampleTapCountries(Tap):
class SampleTapCountries(MsgSpecWriter, Tap):
"""Sample tap for Countries GraphQL API."""

name: str = "sample-tap-countries"
Expand Down
3 changes: 2 additions & 1 deletion samples/sample_tap_sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from singer_sdk import SQLConnector, SQLStream, SQLTap
from singer_sdk import typing as th
from singer_sdk._singerlib.encoding._msgspec import MsgSpecWriter # noqa: PLC2701

DB_PATH_CONFIG = "path_to_db"

Expand Down Expand Up @@ -39,7 +40,7 @@ class SQLiteStream(SQLStream):
STATE_MSG_FREQUENCY = 10


class SQLiteTap(SQLTap):
class SQLiteTap(MsgSpecWriter, SQLTap):
"""The Tap class for SQLite."""

name = "tap-sqlite-sample"
Expand Down
110 changes: 110 additions & 0 deletions singer_sdk/_singerlib/encoding/_msgspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from __future__ import annotations

import datetime
import decimal
import logging
import sys
import typing as t

import msgspec

from singer_sdk._singerlib.exceptions import InvalidInputLine

from ._base import GenericSingerReader, GenericSingerWriter
from ._simple import Message

logger = logging.getLogger(__name__)


def enc_hook(obj: t.Any) -> t.Any:  # noqa: ANN401
    """Convert a value without native encoder support into a string.

    Args:
        obj: The value to convert.

    Returns:
        For ``datetime.datetime`` instances, the ISO 8601 representation using
        ``T`` as the date/time separator; for anything else, ``str(obj)``.
    """
    # Everything that is not a datetime falls back to its plain string form.
    if not isinstance(obj, datetime.datetime):
        return str(obj)
    return obj.isoformat(sep="T")


def dec_hook(type: type, obj: t.Any) -> t.Any:  # noqa: ARG001, A002, ANN401
    """Convert a decoded value of a non-native type into a string.

    Args:
        type: The target type requested by the decoder. It is accepted to
            satisfy the hook signature but is not consulted.
        obj: The value to convert.

    Returns:
        ``str(obj)``, regardless of the requested type.
    """
    as_text = str(obj)
    return as_text


# Shared codec instances, created once at import time and reused by the
# helpers and classes below.
#
# Values the encoder cannot handle natively are passed through ``enc_hook``;
# ``decimal_format="number"`` requests that ``decimal.Decimal`` values be
# written as JSON numbers rather than strings.
encoder = msgspec.json.Encoder(enc_hook=enc_hook, decimal_format="number")
# ``float_hook=decimal.Decimal`` makes the decoder build JSON floats with
# ``decimal.Decimal``; ``dec_hook`` handles non-native target types.
decoder = msgspec.json.Decoder(dec_hook=dec_hook, float_hook=decimal.Decimal)
# Scratch buffer that ``serialize_jsonl`` encodes into on every call
# (64 bytes is only the initial allocation).
_jsonl_msg_buffer = bytearray(64)


def serialize_jsonl(obj: object, **kwargs: t.Any) -> bytes:  # noqa: ARG001
    """Serialize an object into one newline-terminated line of JSON.

    Args:
        obj: A Python object, usually a dict.
        **kwargs: Optional keyword arguments. Accepted for signature
            compatibility and ignored.

    Returns:
        The encoded JSON document followed by ``b"\\n"``, as immutable bytes.
    """
    # Encode into the shared module-level scratch buffer so the encoder does
    # not have to allocate a fresh output buffer for every message.
    encoder.encode_into(obj, _jsonl_msg_buffer)
    _jsonl_msg_buffer.extend(b"\n")
    # Return an immutable snapshot rather than the scratch buffer itself:
    # the buffer is overwritten by the next call, so handing it out would
    # silently corrupt any result a caller keeps, and a bytearray does not
    # match the declared ``bytes`` return type.
    return bytes(_jsonl_msg_buffer)


class MsgSpecReader(GenericSingerReader[str]):
    """Reader for plugins that consume Singer messages as text lines.

    Input defaults to ``sys.stdin`` and each line is parsed with the
    module-level msgspec ``decoder``.
    """

    default_input = sys.stdin

    def deserialize_json(self, line: str) -> dict:  # noqa: PLR6301
        """Parse one line of JSON text.

        Args:
            line: A single line of JSON.

        Returns:
            The value produced by decoding the line.

        Raises:
            InvalidInputLine: If msgspec reports a decode error for the line.
        """
        try:
            parsed = decoder.decode(line)
        except msgspec.DecodeError as exc:
            # Log the offending line with the traceback, then re-raise as the
            # reader's own exception type, chained to the decode error.
            logger.exception("Unable to parse:\n%s", line)
            msg = f"Unable to parse line as JSON: {line}"
            raise InvalidInputLine(msg) from exc
        return parsed  # type: ignore[no-any-return]


class MsgSpecWriter(GenericSingerWriter[bytes, Message]):
    """Writer for plugins that emit Singer messages to stdout as bytes."""

    def serialize_message(self, message: Message) -> bytes:  # noqa: PLR6301
        """Encode a Singer message as one line of JSON.

        Args:
            message: A Singer message object.

        Returns:
            The result of ``serialize_jsonl`` applied to ``message.to_dict()``.
        """
        payload = message.to_dict()
        return serialize_jsonl(payload)

    def write_message(self, message: Message) -> None:
        """Write a message to stdout and flush.

        Args:
            message: The message to write.
        """
        # ``format_message`` is not defined in this class; presumably it is
        # inherited from the generic writer base.
        formatted = self.format_message(message)
        # The payload is bytes, so it goes to the binary layer under stdout.
        sys.stdout.buffer.write(formatted)
        sys.stdout.flush()
1 change: 1 addition & 0 deletions tests/_singerlib/encoding/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from __future__ import annotations # noqa: INP001
Loading
Loading