Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

AVRO-2921: Strict Type Checking #953

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions lang/py/avro/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
import struct
import sys
import zlib
from array import array
from mmap import mmap
from typing import List, Sequence, Tuple, Union

import avro.errors
import avro.io
Expand Down Expand Up @@ -66,7 +69,7 @@ class Codec(abc.ABC):
"""Abstract base class for all Avro codec classes."""

@abc.abstractmethod
def compress(self, data):
def compress(self, data: bytes) -> Tuple[bytes, int]:
"""Compress the passed data.

:param data: a byte string to be compressed
Expand All @@ -77,7 +80,7 @@ def compress(self, data):
"""

@abc.abstractmethod
def decompress(self, readers_decoder):
def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder:
"""Read compressed data via the passed BinaryDecoder and decompress it.

:param readers_decoder: a BinaryDecoder object currently being used for
Expand All @@ -91,22 +94,22 @@ def decompress(self, readers_decoder):


class NullCodec(Codec):
    """Codec that performs no compression: data passes through unchanged."""

    def compress(self, data: bytes) -> Tuple[bytes, int]:
        """Return *data* unchanged together with its byte length."""
        return data, len(data)

    def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder:
        """Skip the stored length (a long) and return the same decoder.

        No decompression is needed, so the caller keeps reading from
        *readers_decoder* directly.
        """
        readers_decoder.skip_long()
        return readers_decoder


class DeflateCodec(Codec):
def compress(self, data: bytes) -> Tuple[bytes, int]:
    """Deflate-compress *data* and return (compressed_bytes, length).

    :param data: the byte string to compress
    :returns: a tuple of the compressed payload and its byte length
    """
    # The first two bytes and the last byte of zlib's output are the zlib
    # wrapper around the raw deflate stream, which Avro does not store.
    # NOTE(review): zlib's adler32 trailer is 4 bytes, so [2:-1] keeps 3
    # trailing bytes; raw-deflate decoders ignore bytes past end-of-stream,
    # so this preserves the historical behavior — confirm against decompress.
    compressed_data = zlib.compress(data)[2:-1]
    return compressed_data, len(compressed_data)

def decompress(self, readers_decoder):
def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder:
# Compressed data is stored as (length, data), which
# corresponds to how the "bytes" type is encoded.
data = readers_decoder.read_bytes()
Expand All @@ -119,11 +122,11 @@ def decompress(self, readers_decoder):
if has_bzip2:

class BZip2Codec(Codec):
def compress(self, data: bytes) -> Tuple[bytes, int]:
    """Compress *data* with bzip2 and return (compressed_bytes, length).

    :param data: the byte string to compress
    :returns: a tuple of the compressed payload and its byte length
    """
    compressed_data = bz2.compress(data)
    return compressed_data, len(compressed_data)

def decompress(self, readers_decoder):
def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder:
length = readers_decoder.read_long()
data = readers_decoder.read(length)
uncompressed = bz2.decompress(data)
Expand All @@ -133,13 +136,13 @@ def decompress(self, readers_decoder):
if has_snappy:

class SnappyCodec(Codec):
def compress(self, data: bytes) -> Tuple[bytes, int]:
    """Snappy-compress *data* and return (compressed_bytes, length).

    :param data: the byte string to compress
    :returns: a tuple of the snappy payload plus checksum, and its byte length
    """
    compressed_data = snappy.compress(data)
    # A 4-byte, big-endian CRC32 checksum of the *uncompressed* data is
    # appended after the snappy payload (masked to 32 bits).
    compressed_data += STRUCT_CRC32.pack(binascii.crc32(data) & 0xFFFFFFFF)
    return compressed_data, len(compressed_data)

def decompress(self, readers_decoder):
def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder:
# Compressed data includes a 4-byte CRC32 checksum
length = readers_decoder.read_long()
data = readers_decoder.read(length - 4)
Expand All @@ -148,20 +151,20 @@ def decompress(self, readers_decoder):
self.check_crc32(uncompressed, checksum)
return avro.io.BinaryDecoder(io.BytesIO(uncompressed))

def check_crc32(self, bytes_: bytes, checksum: Union[array[int], bytes, bytearray, memoryview, mmap]) -> None:
    """Verify the CRC32 of *bytes_* against a 4-byte packed *checksum*.

    :param bytes_: the uncompressed payload to checksum
    :param checksum: the packed CRC32 value read from the stream
    :raises avro.errors.AvroException: when the computed CRC32 does not match
    """
    # Renamed from `bytes` to avoid shadowing the builtin.
    expected = STRUCT_CRC32.unpack(checksum)[0]
    if binascii.crc32(bytes_) & 0xFFFFFFFF != expected:
        raise avro.errors.AvroException("Checksum failure")


if has_zstandard:

class ZstandardCodec(Codec):
def compress(self, data: bytes) -> Tuple[bytes, int]:
    """Compress *data* with zstandard and return (compressed_bytes, length).

    :param data: the byte string to compress
    :returns: a tuple of the compressed payload and its byte length
    """
    compressed_data = zstd.ZstdCompressor().compress(data)
    return compressed_data, len(compressed_data)

def decompress(self, readers_decoder):
def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder:
length = readers_decoder.read_long()
data = readers_decoder.read(length)
uncompressed = bytearray()
Expand All @@ -175,7 +178,7 @@ def decompress(self, readers_decoder):
return avro.io.BinaryDecoder(io.BytesIO(uncompressed))


def get_codec(codec_name):
def get_codec(codec_name: str) -> Codec:
codec_name = codec_name.lower()
if codec_name == "null":
return NullCodec()
Expand All @@ -190,7 +193,7 @@ def get_codec(codec_name):
raise avro.errors.UnsupportedCodec(f"Unsupported codec: {codec_name}. (Is it installed?)")


def supported_codec_names():
def supported_codec_names() -> List[str]:
codec_names = ["null", "deflate"]
if has_bzip2:
codec_names.append("bzip2")
Expand Down
17 changes: 7 additions & 10 deletions lang/py/avro/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# limitations under the License.
from copy import copy
from enum import Enum
from typing import List, Optional, Set, cast
from typing import Any, List, MutableMapping, Optional, Set, cast

from avro.errors import AvroRuntimeException
from avro.schema import (
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(
incompatibilities: List[SchemaIncompatibilityType] = None,
messages: Optional[Set[str]] = None,
locations: Optional[Set[str]] = None,
):
) -> None:
self.locations = locations or {"/"}
self.messages = messages or set()
self.compatibility = compatibility
Expand Down Expand Up @@ -128,16 +128,15 @@ def __init__(self, reader: Schema, writer: Schema) -> None:
def __hash__(self) -> int:
    # Hash on the identities of the paired schemas, consistent with the
    # identity-based (`is`) comparison used by __eq__.
    return id(self.reader) ^ id(self.writer)

def __eq__(self, other: Any) -> bool:
    """Identity-based equality: *other* must pair the very same reader and
    writer schema objects (not merely equal ones)."""
    return isinstance(other, ReaderWriter) and (self.reader is other.reader) and (self.writer is other.writer)


class ReaderWriterCompatibilityChecker:
ROOT_REFERENCE_TOKEN = "/"
memoize_map: MutableMapping[ReaderWriter, SchemaCompatibilityResult]

def __init__(self) -> None:
    """Start with an empty memoization cache of ReaderWriter compatibility results."""
    self.memoize_map = {}

def get_compatibility(
Expand Down Expand Up @@ -374,9 +373,7 @@ def incompatible(incompat_type: SchemaIncompatibilityType, message: str, locatio


def schema_name_equals(reader: NamedSchema, writer: NamedSchema) -> bool:
    """Return True when the schemas match by name, or when the writer's
    fullname appears among the reader's declared aliases.

    :param reader: the reader's named schema
    :param writer: the writer's named schema
    """
    return (reader.name == writer.name) or (writer.fullname in reader.props.get("aliases", []))


def lookup_writer_field(writer_schema: RecordSchema, reader_field: Field) -> Optional[Field]:
Expand Down
Loading