From eb17439aec7c40e75b0e0da1c6b8a5539ae9c70e Mon Sep 17 00:00:00 2001 From: Iwan Aucamp Date: Sat, 16 Oct 2021 16:43:04 +0200 Subject: [PATCH] Make Result.serialize work more like Graph.serialize --- rdflib/graph.py | 243 ++++++--- rdflib/parser.py | 10 +- rdflib/plugin.py | 62 ++- rdflib/plugins/serializers/jsonld.py | 11 +- rdflib/plugins/serializers/n3.py | 4 +- rdflib/plugins/serializers/nquads.py | 15 +- rdflib/plugins/serializers/nt.py | 30 +- rdflib/plugins/serializers/rdfxml.py | 46 +- rdflib/plugins/serializers/trig.py | 105 ++-- rdflib/plugins/serializers/trix.py | 11 +- rdflib/plugins/serializers/turtle.py | 90 ++-- rdflib/plugins/sparql/results/csvresults.py | 29 +- rdflib/plugins/sparql/results/jsonresults.py | 13 +- rdflib/plugins/sparql/results/txtresults.py | 28 +- rdflib/plugins/sparql/results/xmlresults.py | 20 +- rdflib/query.py | 215 ++++++-- rdflib/serializer.py | 24 +- rdflib/store.py | 23 +- rdflib/term.py | 8 +- rdflib/util.py | 32 ++ test/test_conjunctive_graph.py | 17 +- test/test_issue523.py | 5 +- test/test_serialize.py | 493 +++++++++++++++++-- test/test_sparql.py | 11 +- test/test_sparql_result_serialize.py | 281 +++++++++++ test/test_trig.py | 50 ++ test/test_util.py | 50 +- test/testutils.py | 61 ++- 28 files changed, 1651 insertions(+), 336 deletions(-) create mode 100644 test/test_sparql_result_serialize.py diff --git a/rdflib/graph.py b/rdflib/graph.py index 805bb7c64c..74f76f57ba 100644 --- a/rdflib/graph.py +++ b/rdflib/graph.py @@ -1,4 +1,16 @@ -from typing import Optional, Union, Type, cast, overload, Generator, Tuple +from typing import ( + IO, + Any, + BinaryIO, + Iterable, + Optional, + Union, + Type, + cast, + overload, + Generator, + Tuple, +) import logging from warnings import warn import random @@ -21,7 +33,7 @@ import tempfile import pathlib -from io import BytesIO, BufferedIOBase +from io import BytesIO from urllib.parse import urlparse assert Literal # avoid warning @@ -313,15 +325,20 @@ class Graph(Node): """ def 
__init__( - self, store="default", identifier=None, namespace_manager=None, base=None + self, + store: Union[Store, str] = "default", + identifier: Optional[Union[Node, str]] = None, + namespace_manager: Optional[NamespaceManager] = None, + base: Optional[str] = None, ): super(Graph, self).__init__() self.base = base - self.__identifier = identifier or BNode() + self.__identifier: Node + self.__identifier = identifier or BNode() # type: ignore[assignment] if not isinstance(self.__identifier, Node): - self.__identifier = URIRef(self.__identifier) - + self.__identifier = URIRef(self.__identifier) # type: ignore[unreachable] + self.__store: Store if not isinstance(store, Store): # TODO: error handling self.__store = store = plugin.get(store, Store)() @@ -332,29 +349,37 @@ def __init__( self.formula_aware = False self.default_union = False - def __get_store(self): + def __get_store(self) -> Store: return self.__store - store = property(__get_store) # read-only attr + @property + def store(self) -> Store: # read-only attr + return self.__get_store() - def __get_identifier(self): + def __get_identifier(self) -> Node: return self.__identifier - identifier = property(__get_identifier) # read-only attr + @property + def identifier(self) -> Node: # read-only attr + return self.__get_identifier() - def _get_namespace_manager(self): + def _get_namespace_manager(self) -> NamespaceManager: if self.__namespace_manager is None: self.__namespace_manager = NamespaceManager(self) return self.__namespace_manager - def _set_namespace_manager(self, nm): + def _set_namespace_manager(self, nm: NamespaceManager): self.__namespace_manager = nm - namespace_manager = property( - _get_namespace_manager, - _set_namespace_manager, - doc="this graph's namespace-manager", - ) + @property + def namespace_manager(self) -> NamespaceManager: + """this graph's namespace-manager""" + return self._get_namespace_manager() + + @namespace_manager.setter + def namespace_manager(self, value: 
NamespaceManager): + """this graph's namespace-manager""" + self._set_namespace_manager(value) def __repr__(self): return "" % (self.identifier, type(self)) @@ -404,7 +429,7 @@ def close(self, commit_pending_transaction=False): """ return self.__store.close(commit_pending_transaction=commit_pending_transaction) - def add(self, triple): + def add(self, triple: Tuple[Node, Node, Node]): """Add a triple with self as context""" s, p, o = triple assert isinstance(s, Node), "Subject %s must be an rdflib term" % (s,) @@ -413,7 +438,7 @@ def add(self, triple): self.__store.add((s, p, o), self, quoted=False) return self - def addN(self, quads): + def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]): """Add a sequence of triple with context""" self.__store.addN( @@ -434,7 +459,9 @@ def remove(self, triple): self.__store.remove(triple, context=self) return self - def triples(self, triple): + def triples( + self, triple: Tuple[Optional[Node], Union[None, Path, Node], Optional[Node]] + ): """Generator over the triple store Returns triples that match the given triple pattern. 
If triple pattern @@ -652,17 +679,17 @@ def set(self, triple): self.add((subject, predicate, object_)) return self - def subjects(self, predicate=None, object=None): + def subjects(self, predicate=None, object=None) -> Iterable[Node]: """A generator of subjects with the given predicate and object""" for s, p, o in self.triples((None, predicate, object)): yield s - def predicates(self, subject=None, object=None): + def predicates(self, subject=None, object=None) -> Iterable[Node]: """A generator of predicates with the given subject and object""" for s, p, o in self.triples((subject, None, object)): yield p - def objects(self, subject=None, predicate=None): + def objects(self, subject=None, predicate=None) -> Iterable[Node]: """A generator of objects with the given subject and predicate""" for s, p, o in self.triples((subject, predicate, None)): yield o @@ -1011,7 +1038,12 @@ def absolutize(self, uri, defrag=1): # no destination and non-None positional encoding @overload def serialize( - self, destination: None, format: str, base: Optional[str], encoding: str, **args + self, + destination: None, + format: str, + base: Optional[str], + encoding: str, + **args, ) -> bytes: ... @@ -1019,45 +1051,32 @@ def serialize( @overload def serialize( self, - *, destination: None = ..., format: str = ..., base: Optional[str] = ..., + *, encoding: str, **args, ) -> bytes: ... - # no destination and None positional encoding + # no destination and None encoding @overload def serialize( self, - destination: None, - format: str, - base: Optional[str], - encoding: None, - **args, - ) -> str: - ... - - # no destination and None keyword encoding - @overload - def serialize( - self, - *, destination: None = ..., format: str = ..., base: Optional[str] = ..., - encoding: None = None, + encoding: None = ..., **args, ) -> str: ... 
- # non-none destination + # non-None destination @overload def serialize( self, - destination: Union[str, BufferedIOBase, pathlib.PurePath], + destination: Union[str, pathlib.PurePath, IO[bytes]], format: str = ..., base: Optional[str] = ..., encoding: Optional[str] = ..., @@ -1069,34 +1088,53 @@ def serialize( @overload def serialize( self, - destination: Union[str, BufferedIOBase, pathlib.PurePath, None] = None, - format: str = "turtle", - base: Optional[str] = None, - encoding: Optional[str] = None, + destination: Optional[Union[str, pathlib.PurePath, IO[bytes]]] = ..., + format: str = ..., + base: Optional[str] = ..., + encoding: Optional[str] = ..., **args, ) -> Union[bytes, str, "Graph"]: ... def serialize( self, - destination: Union[str, BufferedIOBase, pathlib.PurePath, None] = None, + destination: Optional[Union[str, pathlib.PurePath, IO[bytes]]] = None, format: str = "turtle", base: Optional[str] = None, encoding: Optional[str] = None, - **args, + **args: Any, ) -> Union[bytes, str, "Graph"]: - """Serialize the Graph to destination - - If destination is None serialize method returns the serialization as - bytes or string. - - If encoding is None and destination is None, returns a string - If encoding is set, and Destination is None, returns bytes - - Format defaults to turtle. - - Format support can be extended with plugins, - but "xml", "n3", "turtle", "nt", "pretty-xml", "trix", "trig" and "nquads" are built in. + """ + Serialize the graph. + + :param destination: + The destination to serialize the graph to. This can be a path as a + :class:`str` or :class:`~pathlib.PurePath` object, or it can be a + :class:`~typing.IO[bytes]` like object. If this parameter is not + supplied the serialized graph will be returned. + :type destination: Optional[Union[str, typing.IO[bytes], pathlib.PurePath]] + :param format: + The format that the output should be written in. This value + references a :class:`~rdflib.serializer.Serializer` plugin. 
Format + support can be extended with plugins, but `"xml"`, `"n3"`, + `"turtle"`, `"nt"`, `"pretty-xml"`, `"trix"`, `"trig"`, `"nquads"` + and `"json-ld"` are built in. Defaults to `"turtle"`. + :type format: str + :param base: + The base IRI for formats that support it. For the turtle format this + will be used as the `@base` directive. + :type base: Optional[str] + :param encoding: Encoding of output. + :type encoding: Optional[str] + :param **args: + Additional arguments to pass to the + :class:`~rdflib.serializer.Serializer` that will be used. + :type **args: Any + :return: The serialized graph if `destination` is `None`. + :rtype: :class:`bytes` if `destination` is `None` and `encoding` is not `None`. + :rtype: :class:`str` if `destination` is `None` and `encoding` is `None`. + :return: `self` (i.e. the :class:`~rdflib.graph.Graph` instance) if `destination` is not None. + :rtype: :class:`~rdflib.graph.Graph` if `destination` is not None. """ # if base is not given as attribute use the base set for the graph @@ -1104,7 +1142,7 @@ def serialize( base = self.base serializer = plugin.get(format, Serializer)(self) - stream: BufferedIOBase + stream: IO[bytes] if destination is None: stream = BytesIO() if encoding is None: @@ -1114,7 +1152,7 @@ def serialize( serializer.serialize(stream, base=base, encoding=encoding, **args) return stream.getvalue() if hasattr(destination, "write"): - stream = cast(BufferedIOBase, destination) + stream = cast(IO[bytes], destination) serializer.serialize(stream, base=base, encoding=encoding, **args) else: if isinstance(destination, pathlib.PurePath): @@ -1149,10 +1187,10 @@ def parse( self, source=None, publicID=None, - format=None, + format: Optional[str] = None, location=None, file=None, - data=None, + data: Optional[Union[str, bytes, bytearray]] = None, **args, ): """ @@ -1293,7 +1331,7 @@ def query( if none are given, the namespaces from the graph's namespace manager are used. 
- :returntype: rdflib.query.Result + :returntype: :class:`~rdflib.query.Result` """ @@ -1537,7 +1575,12 @@ class ConjunctiveGraph(Graph): All queries are carried out against the union of all graphs. """ - def __init__(self, store="default", identifier=None, default_graph_base=None): + def __init__( + self, + store: Union[Store, str] = "default", + identifier: Optional[Union[Node, str]] = None, + default_graph_base: Optional[str] = None, + ): super(ConjunctiveGraph, self).__init__(store, identifier=identifier) assert self.store.context_aware, ( "ConjunctiveGraph must be backed by" " a context aware store." @@ -1555,7 +1598,31 @@ def __str__(self): ) return pattern % self.store.__class__.__name__ - def _spoc(self, triple_or_quad, default=False): + @overload + def _spoc( + self, + triple_or_quad: Union[ + Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node] + ], + default: bool = False, + ) -> Tuple[Node, Node, Node, Optional[Graph]]: + ... + + @overload + def _spoc( + self, + triple_or_quad: None, + default: bool = False, + ) -> Tuple[None, None, None, Optional[Graph]]: + ... 
+ + def _spoc( + self, + triple_or_quad: Optional[ + Union[Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]] + ], + default: bool = False, + ) -> Tuple[Optional[Node], Optional[Node], Optional[Node], Optional[Graph]]: """ helper method for having methods that support either triples or quads @@ -1564,9 +1631,9 @@ def _spoc(self, triple_or_quad, default=False): return (None, None, None, self.default_context if default else None) if len(triple_or_quad) == 3: c = self.default_context if default else None - (s, p, o) = triple_or_quad + (s, p, o) = triple_or_quad # type: ignore[misc] elif len(triple_or_quad) == 4: - (s, p, o, c) = triple_or_quad + (s, p, o, c) = triple_or_quad # type: ignore[misc] c = self._graph(c) return s, p, o, c @@ -1577,7 +1644,7 @@ def __contains__(self, triple_or_quad): return True return False - def add(self, triple_or_quad): + def add(self, triple_or_quad: Union[Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]]) -> "ConjunctiveGraph": # type: ignore[override] """ Add a triple or quad to the store. @@ -1591,7 +1658,15 @@ def add(self, triple_or_quad): self.store.add((s, p, o), context=c, quoted=False) return self - def _graph(self, c): + @overload + def _graph(self, c: Union[Graph, Node, str]) -> Graph: + ... + + @overload + def _graph(self, c: None) -> None: + ... 
+ + def _graph(self, c: Optional[Union[Graph, Node, str]]) -> Optional[Graph]: if c is None: return None if not isinstance(c, Graph): @@ -1599,7 +1674,7 @@ def _graph(self, c): else: return c - def addN(self, quads): + def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]): """Add a sequence of triples with context""" self.store.addN( @@ -1689,13 +1764,19 @@ def contexts(self, triple=None): else: yield self.get_context(context) - def get_context(self, identifier, quoted=False, base=None): + def get_context( + self, + identifier: Optional[Union[Node, str]], + quoted: bool = False, + base: Optional[str] = None, + ) -> Graph: """Return a context graph for the given identifier identifier must be a URIRef or BNode. """ + # TODO: FIXME - why is ConjunctiveGraph passed as namespace_manager? return Graph( - store=self.store, identifier=identifier, namespace_manager=self, base=base + store=self.store, identifier=identifier, namespace_manager=self, base=base # type: ignore[arg-type] ) def remove_context(self, context): @@ -1747,6 +1828,7 @@ def parse( context = Graph(store=self.store, identifier=g_id) context.remove((None, None, None)) # hmm ? context.parse(source, publicID=publicID, format=format, **args) + # TODO: FIXME: This should not return context, but self. 
return context def __reduce__(self): @@ -1977,7 +2059,7 @@ class QuotedGraph(Graph): def __init__(self, store, identifier): super(QuotedGraph, self).__init__(store, identifier) - def add(self, triple): + def add(self, triple: Tuple[Node, Node, Node]): """Add a triple with self as context""" s, p, o = triple assert isinstance(s, Node), "Subject %s must be an rdflib term" % (s,) @@ -1987,7 +2069,7 @@ def add(self, triple): self.store.add((s, p, o), self, quoted=True) return self - def addN(self, quads): + def addN(self, quads: Tuple[Node, Node, Node, Any]) -> "QuotedGraph": # type: ignore[override] """Add a sequence of triple with context""" self.store.addN( @@ -2261,7 +2343,7 @@ class BatchAddGraph(object): """ - def __init__(self, graph, batch_size=1000, batch_addn=False): + def __init__(self, graph: Graph, batch_size: int = 1000, batch_addn: bool = False): if not batch_size or batch_size < 2: raise ValueError("batch_size must be a positive number") self.graph = graph @@ -2278,7 +2360,10 @@ def reset(self): self.count = 0 return self - def add(self, triple_or_quad): + def add( + self, + triple_or_quad: Union[Tuple[Node, Node, Node], Tuple[Node, Node, Node, Any]], + ) -> "BatchAddGraph": """ Add a triple to the buffer @@ -2294,7 +2379,7 @@ def add(self, triple_or_quad): self.batch.append(triple_or_quad) return self - def addN(self, quads): + def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]): if self.__batch_addn: for q in quads: self.add(q) diff --git a/rdflib/parser.py b/rdflib/parser.py index f0014150f6..1f8a490cde 100644 --- a/rdflib/parser.py +++ b/rdflib/parser.py @@ -16,6 +16,7 @@ import sys from io import BytesIO, TextIOBase, TextIOWrapper, StringIO, BufferedIOBase +from typing import Optional, Union from urllib.request import Request from urllib.request import url2pathname @@ -44,7 +45,7 @@ class Parser(object): def __init__(self): pass - def parse(self, source, sink): + def parse(self, source, sink, **args): pass @@ -214,7 +215,12 @@ def 
__repr__(self): def create_input_source( - source=None, publicID=None, location=None, file=None, data=None, format=None + source=None, + publicID=None, + location=None, + file=None, + data: Optional[Union[str, bytes, bytearray]] = None, + format=None, ): """ Return an appropriate InputSource instance for the given diff --git a/rdflib/plugin.py b/rdflib/plugin.py index 719c7eaf55..ac3a7fbd06 100644 --- a/rdflib/plugin.py +++ b/rdflib/plugin.py @@ -36,7 +36,21 @@ UpdateProcessor, ) from rdflib.exceptions import Error -from typing import Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + Iterator, + Optional, + Tuple, + Type, + TypeVar, + overload, +) + +if TYPE_CHECKING: + from pkg_resources import EntryPoint __all__ = ["register", "get", "plugins", "PluginException", "Plugin", "PKGPlugin"] @@ -51,42 +65,47 @@ "rdf.plugins.updateprocessor": UpdateProcessor, } -_plugins = {} +_plugins: Dict[Tuple[str, Type[Any]], "Plugin"] = {} class PluginException(Error): pass -class Plugin(object): - def __init__(self, name, kind, module_path, class_name): +PluginT = TypeVar("PluginT") + + +class Plugin(Generic[PluginT]): + def __init__( + self, name: str, kind: Type[PluginT], module_path: str, class_name: str + ): self.name = name self.kind = kind self.module_path = module_path self.class_name = class_name - self._class = None + self._class: Optional[Type[PluginT]] = None - def getClass(self): + def getClass(self) -> Type[PluginT]: if self._class is None: module = __import__(self.module_path, globals(), locals(), [""]) self._class = getattr(module, self.class_name) return self._class -class PKGPlugin(Plugin): - def __init__(self, name, kind, ep): +class PKGPlugin(Plugin[PluginT]): + def __init__(self, name: str, kind: Type[PluginT], ep: "EntryPoint"): self.name = name self.kind = kind self.ep = ep - self._class = None + self._class: Optional[Type[PluginT]] = None - def getClass(self): + def getClass(self) -> Type[PluginT]: if self._class is None: 
self._class = self.ep.load() return self._class -def register(name: str, kind, module_path, class_name): +def register(name: str, kind: Type[Any], module_path, class_name): """ Register the plugin for (name, kind). The module_path and class_name should be the path to a plugin class. @@ -95,16 +114,13 @@ def register(name: str, kind, module_path, class_name): _plugins[(name, kind)] = p -PluginT = TypeVar("PluginT") - - def get(name: str, kind: Type[PluginT]) -> Type[PluginT]: """ Return the class for the specified (name, kind). Raises a PluginException if unable to do so. """ try: - p = _plugins[(name, kind)] + p: Plugin[PluginT] = _plugins[(name, kind)] except KeyError: raise PluginException("No plugin registered for (%s, %s)" % (name, kind)) return p.getClass() @@ -121,7 +137,21 @@ def get(name: str, kind: Type[PluginT]) -> Type[PluginT]: _plugins[(ep.name, kind)] = PKGPlugin(ep.name, kind, ep) -def plugins(name=None, kind=None): +@overload +def plugins( + name: Optional[str] = ..., kind: Type[PluginT] = ... +) -> Iterator[Plugin[PluginT]]: + ... + + +@overload +def plugins(name: Optional[str] = ..., kind: None = ...) -> Iterator[Plugin]: + ... + + +def plugins( + name: Optional[str] = None, kind: Optional[Type[PluginT]] = None +) -> Iterator[Plugin]: """ A generator of the plugins. 
diff --git a/rdflib/plugins/serializers/jsonld.py b/rdflib/plugins/serializers/jsonld.py index 67f3b86232..f5067e2873 100644 --- a/rdflib/plugins/serializers/jsonld.py +++ b/rdflib/plugins/serializers/jsonld.py @@ -41,6 +41,7 @@ from rdflib.graph import Graph from rdflib.term import URIRef, Literal, BNode from rdflib.namespace import RDF, XSD +from typing import IO, Optional from ..shared.jsonld.context import Context, UNDEF from ..shared.jsonld.util import json @@ -53,10 +54,16 @@ class JsonLDSerializer(Serializer): - def __init__(self, store): + def __init__(self, store: Graph): super(JsonLDSerializer, self).__init__(store) - def serialize(self, stream, base=None, encoding=None, **kwargs): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **kwargs + ): # TODO: docstring w. args and return value encoding = encoding or "utf-8" if encoding not in ("utf-8", "utf-16"): diff --git a/rdflib/plugins/serializers/n3.py b/rdflib/plugins/serializers/n3.py index 6c4e2ec46d..032c779f0a 100644 --- a/rdflib/plugins/serializers/n3.py +++ b/rdflib/plugins/serializers/n3.py @@ -14,7 +14,7 @@ class N3Serializer(TurtleSerializer): short_name = "n3" - def __init__(self, store, parent=None): + def __init__(self, store: Graph, parent=None): super(N3Serializer, self).__init__(store) self.keywords.update({OWL.sameAs: "=", SWAP_LOG.implies: "=>"}) self.parent = parent @@ -109,7 +109,7 @@ def p_clause(self, node, position): self.write("{") self.depth += 1 serializer = N3Serializer(node, parent=self) - serializer.serialize(self.stream) + serializer.serialize(self.stream.buffer) self.depth -= 1 self.write(self.indent() + "}") return True diff --git a/rdflib/plugins/serializers/nquads.py b/rdflib/plugins/serializers/nquads.py index 54ee42ba12..e76c747d49 100644 --- a/rdflib/plugins/serializers/nquads.py +++ b/rdflib/plugins/serializers/nquads.py @@ -1,5 +1,7 @@ +from typing import IO, Optional import warnings +from rdflib.graph 
import ConjunctiveGraph, Graph from rdflib.term import Literal from rdflib.serializer import Serializer @@ -9,15 +11,22 @@ class NQuadsSerializer(Serializer): - def __init__(self, store): + def __init__(self, store: Graph): if not store.context_aware: raise Exception( "NQuads serialization only makes " "sense for context-aware stores!" ) super(NQuadsSerializer, self).__init__(store) - - def serialize(self, stream, base=None, encoding=None, **args): + self.store: ConjunctiveGraph + + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): if base is not None: warnings.warn("NQuadsSerializer does not support base.") if encoding is not None and encoding.lower() != self.encoding.lower(): diff --git a/rdflib/plugins/serializers/nt.py b/rdflib/plugins/serializers/nt.py index bc265ee5f4..df1b9e6baa 100644 --- a/rdflib/plugins/serializers/nt.py +++ b/rdflib/plugins/serializers/nt.py @@ -3,12 +3,17 @@ See for details about the format. """ +from typing import IO, Optional + +from rdflib.graph import Graph from rdflib.term import Literal from rdflib.serializer import Serializer import warnings import codecs +from rdflib.util import as_textio + __all__ = ["NTSerializer"] @@ -17,19 +22,32 @@ class NTSerializer(Serializer): Serializes RDF graphs to NTriples format. 
""" - def __init__(self, store): + def __init__(self, store: Graph): Serializer.__init__(self, store) self.encoding = "ascii" # n-triples are ascii encoded - def serialize(self, stream, base=None, encoding=None, **args): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): if base is not None: warnings.warn("NTSerializer does not support base.") if encoding is not None and encoding.lower() != self.encoding.lower(): warnings.warn("NTSerializer does not use custom encoding.") encoding = self.encoding - for triple in self.store: - stream.write(_nt_row(triple).encode(self.encoding, "_rdflib_nt_escape")) - stream.write("\n".encode("latin-1")) + + with as_textio( + stream, + encoding=self.encoding, + errors="_rdflib_nt_escape", + write_through=True, + ) as text_stream: + for triple in self.store: + text_stream.write(_nt_row(triple)) + text_stream.write("\n") class NT11Serializer(NTSerializer): @@ -39,7 +57,7 @@ class NT11Serializer(NTSerializer): Exactly like nt - only utf8 encoded. 
""" - def __init__(self, store): + def __init__(self, store: Graph): Serializer.__init__(self, store) # default to utf-8 diff --git a/rdflib/plugins/serializers/rdfxml.py b/rdflib/plugins/serializers/rdfxml.py index 72648afbac..fbcce0275d 100644 --- a/rdflib/plugins/serializers/rdfxml.py +++ b/rdflib/plugins/serializers/rdfxml.py @@ -1,9 +1,11 @@ +from typing import IO, Dict, Optional, Set, cast from rdflib.plugins.serializers.xmlwriter import XMLWriter from rdflib.namespace import Namespace, RDF, RDFS # , split_uri from rdflib.plugins.parsers.RDFVOC import RDFVOC -from rdflib.term import URIRef, Literal, BNode +from rdflib.graph import Graph +from rdflib.term import Identifier, URIRef, Literal, BNode from rdflib.util import first, more_than from rdflib.collection import Collection from rdflib.serializer import Serializer @@ -17,7 +19,7 @@ class XMLSerializer(Serializer): - def __init__(self, store): + def __init__(self, store: Graph): super(XMLSerializer, self).__init__(store) def __bindings(self): @@ -39,14 +41,20 @@ def __bindings(self): for prefix, namespace in bindings.items(): yield prefix, namespace - def serialize(self, stream, base=None, encoding=None, **args): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): # if base is given here, use that, if not and a base is set for the graph use that if base is not None: self.base = base elif self.store.base is not None: self.base = self.store.base self.__stream = stream - self.__serialized = {} + self.__serialized: Dict[Identifier, int] = {} encoding = self.encoding self.write = write = lambda uni: stream.write(uni.encode(encoding, "replace")) @@ -154,12 +162,20 @@ def fix(val): class PrettyXMLSerializer(Serializer): - def __init__(self, store, max_depth=3): + def __init__(self, store: Graph, max_depth=3): super(PrettyXMLSerializer, self).__init__(store) - self.forceRDFAbout = set() - - def serialize(self, stream, base=None, encoding=None, 
**args): - self.__serialized = {} + self.forceRDFAbout: Set[URIRef] = set() + + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): + # TODO FIXME: this should be Optional, but it's not because nothing + # treats it as such. + self.__serialized: Dict[Identifier, int] = {} store = self.store # if base is given here, use that, if not and a base is set for the graph use that if base is not None: @@ -190,8 +206,9 @@ def serialize(self, stream, base=None, encoding=None, **args): writer.namespaces(namespaces.items()) + subject: Identifier # Write out subjects that can not be inline - for subject in store.subjects(): + for subject in store.subjects(): # type: ignore[assignment] if (None, None, subject) in store: if (subject, None, subject) in store: self.subject(subject, 1) @@ -202,7 +219,7 @@ def serialize(self, stream, base=None, encoding=None, **args): # write out BNodes last (to ensure they can be inlined where possible) bnodes = set() - for subject in store.subjects(): + for subject in store.subjects(): # type: ignore[assignment] if isinstance(subject, BNode): bnodes.add(subject) continue @@ -217,13 +234,14 @@ def serialize(self, stream, base=None, encoding=None, **args): stream.write("\n".encode("latin-1")) # Set to None so that the memory can get garbage collected. 
- self.__serialized = None + self.__serialized = None # type: ignore[assignment] - def subject(self, subject, depth=1): + def subject(self, subject: Identifier, depth: int = 1): store = self.store writer = self.writer if subject in self.forceRDFAbout: + subject = cast(URIRef, subject) writer.push(RDFVOC.Description) writer.attribute(RDFVOC.about, self.relativize(subject)) writer.pop(RDFVOC.Description) @@ -264,6 +282,8 @@ def subj_as_obj_more_than(ceil): writer.pop(element) elif subject in self.forceRDFAbout: + # TODO FIXME?: this looks like a duplicate of first condition + subject = cast(URIRef, subject) writer.push(RDFVOC.Description) writer.attribute(RDFVOC.about, self.relativize(subject)) writer.pop(RDFVOC.Description) diff --git a/rdflib/plugins/serializers/trig.py b/rdflib/plugins/serializers/trig.py index cdaedd4892..41bb6cc7ec 100644 --- a/rdflib/plugins/serializers/trig.py +++ b/rdflib/plugins/serializers/trig.py @@ -4,9 +4,12 @@ """ from collections import defaultdict +from typing import IO, Optional, Union, cast +from rdflib.graph import ConjunctiveGraph, Graph from rdflib.plugins.serializers.turtle import TurtleSerializer -from rdflib.term import BNode +from rdflib.term import BNode, Node + __all__ = ["TrigSerializer"] @@ -16,8 +19,10 @@ class TrigSerializer(TurtleSerializer): short_name = "trig" indentString = 4 * " " - def __init__(self, store): + def __init__(self, store: Union[Graph, ConjunctiveGraph]): + self.default_context: Optional[Node] if store.context_aware: + store = cast("ConjunctiveGraph", store) self.contexts = list(store.contexts()) self.default_context = store.default_context.identifier if store.default_context: @@ -48,53 +53,53 @@ def reset(self): super(TrigSerializer, self).reset() self._contexts = {} - def serialize(self, stream, base=None, encoding=None, spacious=None, **args): - self.reset() - self.stream = stream - # if base is given here, use that, if not and a base is set for the graph use that - if base is not None: - self.base 
= base - elif self.store.base is not None: - self.base = self.store.base - - if spacious is not None: - self._spacious = spacious - - self.preprocess() - - self.startDocument() - - firstTime = True - for store, (ordered_subjects, subjects, ref) in self._contexts.items(): - if not ordered_subjects: - continue - - self._references = ref - self._serialized = {} - self.store = store - self._subjects = subjects - - if self.default_context and store.identifier == self.default_context: - self.write(self.indent() + "\n{") - else: - if isinstance(store.identifier, BNode): - iri = store.identifier.n3() + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + spacious: Optional[bool] = None, + **args + ): + self._serialize_init(stream, base, encoding, spacious) + try: + self.preprocess() + + self.startDocument() + + firstTime = True + for store, (ordered_subjects, subjects, ref) in self._contexts.items(): + if not ordered_subjects: + continue + + self._references = ref + self._serialized = {} + self.store = store + self._subjects = subjects + + if self.default_context and store.identifier == self.default_context: + self.write(self.indent() + "\n{") else: - iri = self.getQName(store.identifier) - if iri is None: + if isinstance(store.identifier, BNode): iri = store.identifier.n3() - self.write(self.indent() + "\n%s {" % iri) - - self.depth += 1 - for subject in ordered_subjects: - if self.isDone(subject): - continue - if firstTime: - firstTime = False - if self.statement(subject) and not firstTime: - self.write("\n") - self.depth -= 1 - self.write("}\n") - - self.endDocument() - stream.write("\n".encode("latin-1")) + else: + iri = self.getQName(store.identifier) + if iri is None: + iri = store.identifier.n3() + self.write(self.indent() + "\n%s {" % iri) + + self.depth += 1 + for subject in ordered_subjects: + if self.isDone(subject): + continue + if firstTime: + firstTime = False + if self.statement(subject) and not 
firstTime: + self.write("\n") + self.depth -= 1 + self.write("}\n") + + self.endDocument() + self.write("\n") + finally: + self._serialize_end() diff --git a/rdflib/plugins/serializers/trix.py b/rdflib/plugins/serializers/trix.py index 05b6f528f3..1612d815cc 100644 --- a/rdflib/plugins/serializers/trix.py +++ b/rdflib/plugins/serializers/trix.py @@ -1,3 +1,4 @@ +from typing import IO, Optional from rdflib.serializer import Serializer from rdflib.plugins.serializers.xmlwriter import XMLWriter @@ -15,14 +16,20 @@ class TriXSerializer(Serializer): - def __init__(self, store): + def __init__(self, store: Graph): super(TriXSerializer, self).__init__(store) if not store.context_aware: raise Exception( "TriX serialization only makes sense for context-aware stores" ) - def serialize(self, stream, base=None, encoding=None, **args): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): nm = self.store.namespace_manager diff --git a/rdflib/plugins/serializers/turtle.py b/rdflib/plugins/serializers/turtle.py index a62c05c421..bf0b3cfc46 100644 --- a/rdflib/plugins/serializers/turtle.py +++ b/rdflib/plugins/serializers/turtle.py @@ -6,10 +6,13 @@ from collections import defaultdict from functools import cmp_to_key +from rdflib.graph import Graph from rdflib.term import BNode, Literal, URIRef from rdflib.exceptions import Error from rdflib.serializer import Serializer from rdflib.namespace import RDF, RDFS +from io import TextIOWrapper +from typing import IO, Dict, Optional __all__ = ["RecursiveSerializer", "TurtleSerializer"] @@ -44,10 +47,13 @@ class RecursiveSerializer(Serializer): indentString = " " roundtrip_prefixes = () - def __init__(self, store): + def __init__(self, store: Graph): super(RecursiveSerializer, self).__init__(store) - self.stream = None + # TODO FIXME: Ideally stream should be optional, but nothing treats it + # as such, so least weird solution is to just type it as not optional + # even 
though it can sometimes be null. + self.stream: IO[str] = None # type: ignore[assignment] self.reset() def addNamespace(self, prefix, uri): @@ -166,9 +172,9 @@ def indent(self, modifier=0): """Returns indent string multiplied by the depth""" return (self.depth + modifier) * self.indentString - def write(self, text): - """Write text in given encoding.""" - self.stream.write(text.encode(self.encoding, "replace")) + def write(self, text: str): + """Write text""" + self.stream.write(text) SUBJECT = 0 @@ -184,15 +190,15 @@ class TurtleSerializer(RecursiveSerializer): short_name = "turtle" indentString = " " - def __init__(self, store): - self._ns_rewrite = {} + def __init__(self, store: Graph): + self._ns_rewrite: Dict[str, str] = {} super(TurtleSerializer, self).__init__(store) self.keywords = {RDF.type: "a"} self.reset() - self.stream = None + self.stream: TextIOWrapper = None # type: ignore[assignment] self._spacious = _SPACIOUS_OUTPUT - def addNamespace(self, prefix, namespace): + def addNamespace(self, prefix: str, namespace: str): # Turtle does not support prefix that start with _ # if they occur in the graph, rewrite to p_blah # this is more complicated since we need to make sure p_blah @@ -223,36 +229,60 @@ def reset(self): self._started = False self._ns_rewrite = {} - def serialize(self, stream, base=None, encoding=None, spacious=None, **args): + def _serialize_init( + self, + stream: IO[bytes], + base: Optional[str], + encoding: Optional[str], + spacious: Optional[bool], + ) -> None: self.reset() - self.stream = stream + if encoding is not None: + self.encoding = encoding + self.stream = TextIOWrapper( + stream, self.encoding, errors="replace", write_through=True + ) # if base is given here, use that, if not and a base is set for the graph use that if base is not None: self.base = base elif self.store.base is not None: self.base = self.store.base - if spacious is not None: self._spacious = spacious - self.preprocess() - subjects_list = self.orderSubjects() - - 
self.startDocument() - - firstTime = True - for subject in subjects_list: - if self.isDone(subject): - continue - if firstTime: - firstTime = False - if self.statement(subject) and not firstTime: - self.write("\n") - - self.endDocument() - stream.write("\n".encode("latin-1")) - - self.base = None + def _serialize_end(self) -> None: + self.stream.flush() + self.stream.detach() + self.stream = None # type: ignore[assignment] + + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + spacious: Optional[bool] = None, + **args, + ): + self._serialize_init(stream, base, encoding, spacious) + try: + self.preprocess() + subjects_list = self.orderSubjects() + + self.startDocument() + + firstTime = True + for subject in subjects_list: + if self.isDone(subject): + continue + if firstTime: + firstTime = False + if self.statement(subject) and not firstTime: + self.write("\n") + + self.endDocument() + self.stream.write("\n") + finally: + self._serialize_end() def preprocessTriple(self, triple): super(TurtleSerializer, self).preprocessTriple(triple) diff --git a/rdflib/plugins/sparql/results/csvresults.py b/rdflib/plugins/sparql/results/csvresults.py index c87b6ea760..aba7ac058b 100644 --- a/rdflib/plugins/sparql/results/csvresults.py +++ b/rdflib/plugins/sparql/results/csvresults.py @@ -9,11 +9,14 @@ import codecs import csv +from typing import IO, TYPE_CHECKING, Optional, TextIO, Union from rdflib import Variable, BNode, URIRef, Literal from rdflib.query import Result, ResultSerializer, ResultParser +from rdflib.util import as_textio + class CSVResultParser(ResultParser): def __init__(self): @@ -61,24 +64,24 @@ def __init__(self, result): if result.type != "SELECT": raise Exception("CSVSerializer can only serialize select query results") - def serialize(self, stream, encoding="utf-8", **kwargs): + def serialize( + self, stream: Union[IO[bytes], TextIO], encoding: Optional[str] = None, **kwargs + ): # the serialiser writes 
bytes in the given encoding # in py3 csv.writer is unicode aware and writes STRINGS, # so we encode afterwards - import codecs - - stream = codecs.getwriter(encoding)(stream) - - out = csv.writer(stream, delimiter=self.delim) - - vs = [self.serializeTerm(v, encoding) for v in self.result.vars] - out.writerow(vs) - for row in self.result.bindings: - out.writerow( - [self.serializeTerm(row.get(v), encoding) for v in self.result.vars] - ) + with as_textio(stream, encoding=encoding) as stream: + out = csv.writer(stream, delimiter=self.delim) + if TYPE_CHECKING: + assert self.result.vars is not None + vs = [self.serializeTerm(v, encoding) for v in self.result.vars] + out.writerow(vs) + for row in self.result.bindings: + out.writerow( + [self.serializeTerm(row.get(v), encoding) for v in self.result.vars] + ) def serializeTerm(self, term, encoding): if term is None: diff --git a/rdflib/plugins/sparql/results/jsonresults.py b/rdflib/plugins/sparql/results/jsonresults.py index 13a8da5eff..5e933c1c77 100644 --- a/rdflib/plugins/sparql/results/jsonresults.py +++ b/rdflib/plugins/sparql/results/jsonresults.py @@ -1,7 +1,9 @@ import json +from typing import IO, Any, Dict, Optional, TextIO, Union from rdflib.query import Result, ResultException, ResultSerializer, ResultParser from rdflib import Literal, URIRef, BNode, Variable +from rdflib.util import as_textio """A Serializer for SPARQL results in JSON: @@ -28,9 +30,10 @@ class JSONResultSerializer(ResultSerializer): def __init__(self, result): ResultSerializer.__init__(self, result) - def serialize(self, stream, encoding=None): - - res = {} + def serialize( + self, stream: Union[IO[bytes], TextIO], encoding: Optional[str] = None, **kwargs + ): + res: Dict[str, Any] = {} if self.result.type == "ASK": res["head"] = {} res["boolean"] = self.result.askAnswer @@ -44,9 +47,7 @@ def serialize(self, stream, encoding=None): ] r = json.dumps(res, allow_nan=False, ensure_ascii=False) - if encoding is not None: - 
stream.write(r.encode(encoding)) - else: + with as_textio(stream, encoding=encoding) as stream: stream.write(r) def _bindingToJSON(self, b): diff --git a/rdflib/plugins/sparql/results/txtresults.py b/rdflib/plugins/sparql/results/txtresults.py index baa5316b48..c3efed967b 100644 --- a/rdflib/plugins/sparql/results/txtresults.py +++ b/rdflib/plugins/sparql/results/txtresults.py @@ -1,8 +1,11 @@ +from typing import IO, TYPE_CHECKING, Optional, TextIO, Union from rdflib import URIRef, BNode, Literal from rdflib.query import ResultSerializer +from rdflib.namespace import NamespaceManager +from rdflib.util import as_textio -def _termString(t, namespace_manager): +def _termString(t, namespace_manager: Optional[NamespaceManager]): if t is None: return "-" if namespace_manager: @@ -21,7 +24,13 @@ class TXTResultSerializer(ResultSerializer): A write only QueryResult serializer for text/ascii tables """ - def serialize(self, stream, encoding, namespace_manager=None): + # TODO FIXME: class specific args should be keyword only. 
+ def serialize( # type: ignore[override] + self, + stream: Union[IO[bytes], TextIO], + encoding: Optional[str], + namespace_manager: Optional[NamespaceManager] = None, + ): """ return a text table of query results """ @@ -42,7 +51,8 @@ def c(s, w): if not self.result: return "(no results)\n" else: - + if TYPE_CHECKING: + assert self.result.vars is not None keys = self.result.vars maxlen = [0] * len(keys) b = [ @@ -53,9 +63,13 @@ def c(s, w): for i in range(len(keys)): maxlen[i] = max(maxlen[i], len(r[i])) - stream.write("|".join([c(k, maxlen[i]) for i, k in enumerate(keys)]) + "\n") - stream.write("-" * (len(maxlen) + sum(maxlen)) + "\n") - for r in sorted(b): + with as_textio(stream) as stream: stream.write( - "|".join([t + " " * (i - len(t)) for i, t in zip(maxlen, r)]) + "\n" + "|".join([c(k, maxlen[i]) for i, k in enumerate(keys)]) + "\n" ) + stream.write("-" * (len(maxlen) + sum(maxlen)) + "\n") + for r in sorted(b): + stream.write( + "|".join([t + " " * (i - len(t)) for i, t in zip(maxlen, r)]) + + "\n" + ) diff --git a/rdflib/plugins/sparql/results/xmlresults.py b/rdflib/plugins/sparql/results/xmlresults.py index 8c77b50ad1..a7ac2df139 100644 --- a/rdflib/plugins/sparql/results/xmlresults.py +++ b/rdflib/plugins/sparql/results/xmlresults.py @@ -1,4 +1,5 @@ import logging +from typing import IO, Optional, TextIO, Union from xml.sax.saxutils import XMLGenerator from xml.dom import XML_NAMESPACE @@ -28,15 +29,17 @@ class XMLResultParser(ResultParser): - def parse(self, source, content_type=None): + # TODO FIXME: content_type should be a keyword only arg. + def parse(self, source, content_type: Optional[str] = None): # type: ignore[override] return XMLResult(source) class XMLResult(Result): - def __init__(self, source, content_type=None): + def __init__(self, source, content_type: Optional[str] = None): try: - parser = etree.XMLParser(huge_tree=True) + # try use as if etree is from lxml, and if not use it as normal. 
+ parser = etree.XMLParser(huge_tree=True) # type: ignore[call-arg] tree = etree.parse(source, parser) except TypeError: tree = etree.parse(source) @@ -55,7 +58,7 @@ def __init__(self, source, content_type=None): if type_ == "SELECT": self.bindings = [] - for result in results: + for result in results: # type: ignore[union-attr] r = {} for binding in result: r[Variable(binding.get("name"))] = parseTerm(binding[0]) @@ -69,7 +72,7 @@ def __init__(self, source, content_type=None): ] else: - self.askAnswer = boolean.text.lower().strip() == "true" + self.askAnswer = boolean.text.lower().strip() == "true" # type: ignore[union-attr] def parseTerm(element): @@ -101,8 +104,11 @@ class XMLResultSerializer(ResultSerializer): def __init__(self, result): ResultSerializer.__init__(self, result) - def serialize(self, stream, encoding="utf-8"): - + def serialize( + self, stream: Union[IO[bytes], TextIO], encoding: Optional[str] = None, **kwargs + ): + if encoding is None: + encoding = "utf-8" writer = SPARQLXMLWriter(stream, encoding) if self.result.type == "ASK": writer.write_header([]) diff --git a/rdflib/query.py b/rdflib/query.py index da174cd1f2..b3649787c3 100644 --- a/rdflib/query.py +++ b/rdflib/query.py @@ -4,14 +4,19 @@ import tempfile import warnings import types -from typing import Optional, Union, cast +import pathlib +from typing import IO, TYPE_CHECKING, List, Optional, TextIO, Union, cast, overload -from io import BytesIO, BufferedIOBase +from io import BytesIO from urllib.parse import urlparse __all__ = ["Processor", "Result", "ResultParser", "ResultSerializer", "ResultException"] +if TYPE_CHECKING: + from .graph import Graph + from .term import Variable + class Processor(object): """ @@ -161,17 +166,17 @@ class Result(object): """ - def __init__(self, type_): + def __init__(self, type_: str): if type_ not in ("CONSTRUCT", "DESCRIBE", "SELECT", "ASK"): raise ResultException("Unknown Result type: %s" % type_) self.type = type_ - self.vars = None + self.vars: 
Optional[List[Variable]] = None self._bindings = None self._genbindings = None - self.askAnswer = None - self.graph = None + self.askAnswer: bool = None # type: ignore[assignment] + self.graph: "Graph" = None # type: ignore[assignment] def _get_bindings(self): if self._genbindings: @@ -192,7 +197,12 @@ def _set_bindings(self, b): ) @staticmethod - def parse(source=None, format=None, content_type=None, **kwargs): + def parse( + source=None, + format: Optional[str] = None, + content_type: Optional[str] = None, + **kwargs, + ): from rdflib import plugin if format: @@ -206,54 +216,176 @@ def parse(source=None, format=None, content_type=None, **kwargs): return parser.parse(source, content_type=content_type, **kwargs) + # None destination and non-None positional encoding + @overload def serialize( self, - destination: Optional[Union[str, BufferedIOBase]] = None, - encoding: str = "utf-8", - format: str = "xml", + destination: None, + encoding: str, + format: Optional[str] = ..., **args, - ) -> Optional[bytes]: - """ - Serialize the query result. + ) -> bytes: + ... + + # None destination and non-None keyword encoding + @overload + def serialize( + self, + *, + destination: None = ..., + encoding: str, + format: Optional[str] = ..., + **args, + ) -> bytes: + ... + + # None destination and None positional encoding + @overload + def serialize( + self, + destination: None, + encoding: None = None, + format: Optional[str] = ..., + **args, + ) -> str: + ... + + # None destination and None keyword encoding + @overload + def serialize( + self, + *, + destination: None = ..., + encoding: None = None, + format: Optional[str] = ..., + **args, + ) -> str: + ... + + # non-none binary destination + @overload + def serialize( + self, + destination: Union[str, pathlib.PurePath, IO[bytes]], + encoding: Optional[str] = ..., + format: Optional[str] = ..., + **args, + ) -> None: + ... 
+ + # non-none text destination + @overload + def serialize( + self, + destination: Union[TextIO], + encoding: None = ..., + format: Optional[str] = ..., + **args, + ) -> None: + ... - The :code:`format` argument determines the Serializer class to use. + # fallback + @overload + def serialize( + self, + destination: Optional[Union[str, pathlib.PurePath, IO[bytes], TextIO]] = ..., + encoding: Optional[str] = ..., + format: Optional[str] = ..., + **args, + ) -> Union[bytes, str, None]: + ... - - csv: :class:`~rdflib.plugins.sparql.results.csvresults.CSVResultSerializer` - - json: :class:`~rdflib.plugins.sparql.results.jsonresults.JSONResultSerializer` - - txt: :class:`~rdflib.plugins.sparql.results.txtresults.TXTResultSerializer` - - xml: :class:`~rdflib.plugins.sparql.results.xmlresults.XMLResultSerializer` + # NOTE: Using TextIO as opposed to IO[str] because I want to be able to use buffer. + def serialize( + self, + destination: Optional[Union[str, pathlib.PurePath, IO[bytes], TextIO]] = None, + encoding: Optional[str] = None, + format: Optional[str] = None, + **args, + ) -> Union[bytes, str, None]: + """ + Serialize the query result. - :param destination: Path of file output or BufferedIOBase object to write the output to. + :param destination: + The destination to serialize the result to. This can be a path as a + :class:`str` or :class:`~pathlib.PurePath` object, or it can be a + :class:`~typing.IO[bytes]` or :class:`~typing.TextIO` like object. If this parameter is not + supplied the serialized result will be returned. + :type destination: Optional[Union[str, typing.IO[bytes], pathlib.PurePath]] :param encoding: Encoding of output. - :param format: One of ['csv', 'json', 'txt', xml'] - :param args: - :return: bytes + :type encoding: Optional[str] + :param format: + The format that the output should be written in. 
+ + For tabular results, the value refers to a + :class:`rdflib.query.ResultSerializer` plugin. Support for the + following tabular formats is built in: + + - `"csv"`: :class:`~rdflib.plugins.sparql.results.csvresults.CSVResultSerializer` + - `"json"`: :class:`~rdflib.plugins.sparql.results.jsonresults.JSONResultSerializer` + - `"txt"`: :class:`~rdflib.plugins.sparql.results.txtresults.TXTResultSerializer` + - `"xml"`: :class:`~rdflib.plugins.sparql.results.xmlresults.XMLResultSerializer` + + For tabular results, the default format is `"txt"`. + + For graph results, the value refers to a + :class:`~rdflib.serializer.Serializer` plugin and is passed to + :func:`~rdflib.graph.Graph.serialize`. Graph format support can be + extended with plugins, but support for `"xml"`, `"n3"`, `"turtle"`, `"nt"`, + `"pretty-xml"`, `"trix"`, `"trig"`, `"nquads"` and `"json-ld"` is + built in. The default graph format is `"turtle"`. + :type format: str """ if self.type in ("CONSTRUCT", "DESCRIBE"): - return self.graph.serialize( - destination, encoding=encoding, format=format, **args + if format is None: + format = "turtle" + if ( + destination is not None + and hasattr(destination, "encoding") + and hasattr(destination, "buffer") + ): + # rudimentary check for TextIO-like objects. 
+ destination = cast(TextIO, destination).buffer + destination = cast( + Optional[Union[str, pathlib.PurePath, IO[bytes]]], destination ) + result = self.graph.serialize( + destination=destination, format=format, encoding=encoding, **args + ) + from rdflib.graph import Graph + + if isinstance(result, Graph): + return None + return result """stolen wholesale from graph.serialize""" from rdflib import plugin + if format is None: + format = "txt" serializer = plugin.get(format, ResultSerializer)(self) + stream: IO[bytes] if destination is None: - streamb: BytesIO = BytesIO() - stream2 = EncodeOnlyUnicode(streamb) - serializer.serialize(stream2, encoding=encoding, **args) - return streamb.getvalue() + stream = BytesIO() + if encoding is None: + serializer.serialize(stream, encoding="utf-8", **args) + return stream.getvalue().decode("utf-8") + else: + serializer.serialize(stream, encoding=encoding, **args) + return stream.getvalue() if hasattr(destination, "write"): - stream = cast(BufferedIOBase, destination) + stream = cast(IO[bytes], destination) serializer.serialize(stream, encoding=encoding, **args) else: - location = cast(str, destination) + if isinstance(destination, pathlib.PurePath): + location = str(destination) + else: + location = cast(str, destination) scheme, netloc, path, params, query, fragment = urlparse(location) if netloc != "": - print( - "WARNING: not saving as location" + "is not a local file reference" + raise ValueError( + f"destination {destination} is not a local file reference" ) - return None fd, name = tempfile.mkstemp() stream = os.fdopen(fd, "wb") serializer.serialize(stream, encoding=encoding, **args) @@ -339,9 +471,22 @@ def parse(self, source, **kwargs): class ResultSerializer(object): - def __init__(self, result): + def __init__(self, result: Result): self.result = result - def serialize(self, stream, encoding="utf-8", **kwargs): + @overload + def serialize(self, stream: IO[bytes], encoding: Optional[str] = ..., **kwargs): + ... 
+ + @overload + def serialize(self, stream: TextIO, encoding: None = ..., **kwargs): + ... + + def serialize( + self, + stream: Union[IO[bytes], TextIO], + encoding: Optional[str] = None, + **kwargs, + ): """return a string properly serialized""" pass # abstract diff --git a/rdflib/serializer.py b/rdflib/serializer.py index ecb8da0a2b..74f29544bc 100644 --- a/rdflib/serializer.py +++ b/rdflib/serializer.py @@ -10,21 +10,31 @@ """ +from typing import IO, TYPE_CHECKING, Optional from rdflib.term import URIRef +if TYPE_CHECKING: + from rdflib.graph import Graph + __all__ = ["Serializer"] -class Serializer(object): - def __init__(self, store): - self.store = store - self.encoding = "UTF-8" - self.base = None +class Serializer: + def __init__(self, store: "Graph"): + self.store: "Graph" = store + self.encoding: str = "utf-8" + self.base: Optional[str] = None - def serialize(self, stream, base=None, encoding=None, **args): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ) -> None: """Abstract method""" - def relativize(self, uri): + def relativize(self, uri: str): base = self.base if base is not None and uri.startswith(base): uri = URIRef(uri.replace(base, "", 1)) diff --git a/rdflib/store.py b/rdflib/store.py index a7aa8d0b09..bb3e8bdc45 100644 --- a/rdflib/store.py +++ b/rdflib/store.py @@ -1,6 +1,12 @@ from io import BytesIO import pickle +from typing_extensions import TYPE_CHECKING from rdflib.events import Dispatcher, Event +from typing import Tuple, Iterable, Optional + +if TYPE_CHECKING: + from .term import Node + from .graph import Graph """ ============ @@ -172,7 +178,7 @@ def __get_node_pickler(self): def create(self, configuration): self.dispatcher.dispatch(StoreCreatedEvent(configuration=configuration)) - def open(self, configuration, create=False): + def open(self, configuration, create: bool = False): """ Opens the store specified by the configuration string. 
If create is True a store will be created if it does not already @@ -204,7 +210,12 @@ def gc(self): pass # RDF APIs - def add(self, triple, context, quoted=False): + def add( + self, + triple: Tuple["Node", "Node", "Node"], + context: Optional["Graph"], + quoted=False, + ): """ Adds the given statement to a specific context or to the model. The quoted argument is interpreted by formula-aware stores to indicate @@ -215,7 +226,7 @@ def add(self, triple, context, quoted=False): """ self.dispatcher.dispatch(TripleAddedEvent(triple=triple, context=context)) - def addN(self, quads): + def addN(self, quads: Iterable[Tuple["Node", "Node", "Node", "Graph"]]): """ Adds each item in the list of statements to a specific context. The quoted argument is interpreted by formula-aware stores to indicate this @@ -283,7 +294,11 @@ def triples_choices(self, triple, context=None): for (s1, p1, o1), cg in self.triples((subject, None, object_), context): yield (s1, p1, o1), cg - def triples(self, triple_pattern, context=None): + def triples( + self, + triple_pattern: Tuple[Optional["Node"], Optional["Node"], Optional["Node"]], + context=None, + ): """ A generator over all the triples matching the pattern. 
Pattern can include any objects for used for comparing against nodes in the store, diff --git a/rdflib/term.py b/rdflib/term.py index eb1f2cb6ca..837e4a726e 100644 --- a/rdflib/term.py +++ b/rdflib/term.py @@ -64,7 +64,7 @@ from urllib.parse import urlparse from decimal import Decimal -from typing import TYPE_CHECKING, Dict, Callable, Union, Type +from typing import TYPE_CHECKING, Dict, Callable, Optional, Union, Type if TYPE_CHECKING: from .paths import AlternativePath, InvPath, NegatedPath, SequencePath, Path @@ -231,10 +231,10 @@ class URIRef(Identifier): __neg__: Callable[["URIRef"], "NegatedPath"] __truediv__: Callable[["URIRef", Union["URIRef", "Path"]], "SequencePath"] - def __new__(cls, value, base=None): + def __new__(cls, value: str, base: Optional[str] = None): if base is not None: ends_in_hash = value.endswith("#") - value = urljoin(base, value, allow_fragments=1) + value = urljoin(base, value, allow_fragments=True) if ends_in_hash: if not value.endswith("#"): value += "#" @@ -248,7 +248,7 @@ def __new__(cls, value, base=None): try: rt = str.__new__(cls, value) except UnicodeDecodeError: - rt = str.__new__(cls, value, "utf-8") + rt = str.__new__(cls, value, "utf-8") # type: ignore[call-overload] return rt def toPython(self): diff --git a/rdflib/util.py b/rdflib/util.py index 28ca083625..c7e32eab73 100644 --- a/rdflib/util.py +++ b/rdflib/util.py @@ -37,8 +37,18 @@ from time import localtime from time import time from time import timezone +from io import TextIOWrapper from os.path import splitext +from typing import ( + IO, + Generator, + Optional, + TextIO, + Union, + cast, + overload, +) from rdflib.exceptions import ContextTypeError from rdflib.exceptions import ObjectTypeError @@ -51,6 +61,7 @@ from rdflib.term import Literal from rdflib.term import URIRef from rdflib.compat import sign +from contextlib import contextmanager __all__ = [ "list2set", @@ -492,6 +503,27 @@ def get_tree( return (mapper(root), sorted(tree, key=sortkey)) +@contextmanager 
+def as_textio( + anyio: Union[IO[bytes], TextIO], + encoding: Optional[str] = None, + errors: Union[str, None] = None, + write_through: bool = False, +) -> Generator[TextIO, None, None]: + if hasattr(anyio, "encoding"): + yield cast(TextIO, anyio) + else: + textio_wrapper = TextIOWrapper( + cast(IO[bytes], anyio), + encoding=encoding, + errors=errors, + write_through=write_through, + ) + yield textio_wrapper + textio_wrapper.flush() + textio_wrapper.detach() + + def test(): import doctest diff --git a/test/test_conjunctive_graph.py b/test/test_conjunctive_graph.py index ed775c4af8..5f326b749a 100644 --- a/test/test_conjunctive_graph.py +++ b/test/test_conjunctive_graph.py @@ -2,10 +2,13 @@ Tests for ConjunctiveGraph that do not depend on the underlying store """ +import unittest from rdflib import ConjunctiveGraph, Graph +from rdflib.namespace import Namespace from rdflib.term import Identifier, URIRef, BNode from rdflib.parser import StringInputSource -from os import path + +from .testutils import GraphHelper DATA = """ @@ -14,6 +17,18 @@ PUBLIC_ID = "http://example.org/record/1" +EG = Namespace("http://example.com/") + + +class TestConjuctiveGraph(unittest.TestCase): + def test_add(self) -> None: + quad = (EG["subject"], EG["predicate"], EG["object"], EG["graph"]) + g = ConjunctiveGraph() + g.add(quad) + quad_set = GraphHelper.quad_set(g) + self.assertEqual(len(quad_set), 1) + self.assertEqual(next(iter(quad_set)), quad) + def test_bnode_publicid(): diff --git a/test/test_issue523.py b/test/test_issue523.py index 2910cdd71a..54210171e3 100644 --- a/test/test_issue523.py +++ b/test/test_issue523.py @@ -9,9 +9,10 @@ def test_issue523(): "SELECT (<../baz> as ?test) WHERE {}", base=rdflib.URIRef("http://example.org/foo/bar"), ) - res = r.serialize(format="csv") + res = r.serialize(format="csv", encoding="utf-8") assert res == b"test\r\nhttp://example.org/baz\r\n", repr(res) - + res = r.serialize(format="csv") + assert res == "test\r\nhttp://example.org/baz\r\n", 
repr(res) # expected result: # test # http://example.org/baz diff --git a/test/test_serialize.py b/test/test_serialize.py index 90fe14df43..85f83462e4 100644 --- a/test/test_serialize.py +++ b/test/test_serialize.py @@ -1,59 +1,474 @@ +import enum +import inspect +import itertools +import sys import unittest -from rdflib import Graph, URIRef -from tempfile import NamedTemporaryFile, TemporaryDirectory +from contextlib import ExitStack +from io import IOBase from pathlib import Path, PurePath +from tempfile import TemporaryDirectory +from test.testutils import GraphHelper, get_unique_plugins +from typing import ( + IO, + Any, + Dict, + Iterable, + NamedTuple, + Optional, + Set, + TextIO, + Tuple, + Union, + cast, +) + +from rdflib import Graph +from rdflib.graph import ConjunctiveGraph +from rdflib.namespace import Namespace +from rdflib.plugin import PluginException +from rdflib.serializer import Serializer + +EG = Namespace("http://example.com/") + + +class DestinationType(str, enum.Enum): + PATH = enum.auto() + PURE_PATH = enum.auto() + PATH_STR = enum.auto() + IO_BYTES = enum.auto() + TEXT_IO = enum.auto() + + +class DestinationFactory: + _counter: int = 0 + + def __init__(self, tmpdir: Path) -> None: + self.tmpdir = tmpdir + + def make( + self, + type: DestinationType, + stack: Optional[ExitStack] = None, + ) -> Tuple[Union[str, Path, PurePath, IO[bytes], TextIO], Path]: + self._counter += 1 + count = self._counter + path = self.tmpdir / f"file-{type}-{count:05d}" + if type is DestinationType.PATH: + return (path, path) + if type is DestinationType.PURE_PATH: + return (PurePath(path), path) + if type is DestinationType.PATH_STR: + return (f"{path}", path) + if type is DestinationType.IO_BYTES: + return ( + path.open("wb") + if stack is None + else stack.enter_context(path.open("wb")), + path, + ) + if type is DestinationType.TEXT_IO: + return ( + path.open("w") + if stack is None + else stack.enter_context(path.open("w")), + path, + ) + raise 
ValueError(f"unsupported type {type}") + + +class GraphType(str, enum.Enum): + QUAD = enum.auto() + TRIPLE = enum.auto() + + +class FormatInfo(NamedTuple): + serializer_name: str + deserializer_name: str + graph_types: Set[GraphType] + encodings: Set[str] + + +class FormatInfos(Dict[str, FormatInfo]): + def add_format( + self, + serializer_name: str, + *, + deserializer_name: Optional[str] = None, + graph_types: Set[GraphType], + encodings: Set[str], + ) -> None: + self[serializer_name] = FormatInfo( + serializer_name, + serializer_name if deserializer_name is None else deserializer_name, + {GraphType.QUAD, GraphType.TRIPLE} if graph_types is None else graph_types, + encodings, + ) + + def select( + self, + *, + name: Optional[Set[str]] = None, + graph_type: Optional[Set[GraphType]] = None, + ) -> Iterable[FormatInfo]: + for format in self.values(): + if graph_type is not None and not graph_type.isdisjoint(format.graph_types): + yield format + if name is not None and format.serializer_name in name: + yield format + + @classmethod + def make_graph(self, format_info: FormatInfo) -> Graph: + if GraphType.QUAD in format_info.graph_types: + return ConjunctiveGraph() + else: + return Graph() + + @classmethod + def make(cls) -> "FormatInfos": + result = cls() + + flexible_formats = { + "trig", + } + for format in flexible_formats: + result.add_format( + format, + graph_types={GraphType.TRIPLE, GraphType.QUAD}, + encodings={"utf-8"}, + ) + + triple_only_formats = { + "turtle", + "nt11", + "xml", + "n3", + } + for format in triple_only_formats: + result.add_format( + format, graph_types={GraphType.TRIPLE}, encodings={"utf-8"} + ) + + quad_only_formats = { + "nquads", + "trix", + "json-ld", + } + for format in quad_only_formats: + result.add_format(format, graph_types={GraphType.QUAD}, encodings={"utf-8"}) + + result.add_format( + "pretty-xml", + deserializer_name="xml", + graph_types={GraphType.TRIPLE}, + encodings={"utf-8"}, + ) + result.add_format( + "ntriples", + 
graph_types={GraphType.TRIPLE}, + encodings={"ascii"}, + ) + + return result + + +format_infos = FormatInfos.make() + + +def assert_graphs_equal( + test_case: unittest.TestCase, lhs: Graph, rhs: Graph, check_context: bool = True +) -> None: + lhs_has_quads = hasattr(lhs, "quads") + rhs_has_quads = hasattr(rhs, "quads") + lhs_set: Set[Any] + rhs_set: Set[Any] + if lhs_has_quads and rhs_has_quads and check_context: + lhs = cast(ConjunctiveGraph, lhs) + rhs = cast(ConjunctiveGraph, rhs) + lhs_set, rhs_set = GraphHelper.quad_sets([lhs, rhs]) + else: + lhs_set, rhs_set = GraphHelper.triple_sets([lhs, rhs]) + test_case.assertEqual(lhs_set, rhs_set) + test_case.assertTrue(len(lhs_set) > 0) + test_case.assertTrue(len(rhs_set) > 0) class TestSerialize(unittest.TestCase): def setUp(self) -> None: - - graph = Graph() - subject = URIRef("example:subject") - predicate = URIRef("example:predicate") - object = URIRef("example:object") self.triple = ( - subject, - predicate, - object, + EG["subject"], + EG["predicate"], + EG["object"], ) - graph.add(self.triple) - self.graph = graph - return super().setUp() + self.context = EG["graph"] + self.quad = (*self.triple, self.context) - def test_serialize_to_purepath(self): - with TemporaryDirectory() as td: - tfpath = PurePath(td) / "out.nt" - self.graph.serialize(destination=tfpath, format="nt") - graph_check = Graph() - graph_check.parse(source=tfpath, format="nt") + conjunctive_graph = ConjunctiveGraph() + conjunctive_graph.add(self.quad) + self.graph = conjunctive_graph - self.assertEqual(self.triple, next(iter(graph_check))) + query = """ + CONSTRUCT { ?subject ?predicate ?object } WHERE { + ?subject ?predicate ?object + } ORDER BY ?object + """ + self.result = self.graph.query(query) + self.assertIsNotNone(self.result.graph) - def test_serialize_to_path(self): - with NamedTemporaryFile() as tf: - tfpath = Path(tf.name) - self.graph.serialize(destination=tfpath, format="nt") - graph_check = Graph() - 
graph_check.parse(source=tfpath, format="nt") + self._tmpdir = TemporaryDirectory() + self.tmpdir = Path(self._tmpdir.name) + + return super().setUp() - self.assertEqual(self.triple, next(iter(graph_check))) + def tearDown(self) -> None: + self._tmpdir.cleanup() - def test_serialize_to_neturl(self): + def test_graph(self) -> None: + quad_set = GraphHelper.quad_set(self.graph) + self.assertEqual(quad_set, {self.quad}) + + def test_all_formats_specified(self) -> None: + plugins = get_unique_plugins(Serializer) + for plugin_refs in plugins.values(): + names = {plugin_ref.name for plugin_ref in plugin_refs} + self.assertNotEqual( + names.intersection(format_infos.keys()), + set(), + f"serializers does not include any of {names}", + ) + + def test_serialize_to_neturl(self) -> None: with self.assertRaises(ValueError) as raised: self.graph.serialize(destination="http://example.com/", format="nt") self.assertIn("destination", f"{raised.exception}") - def test_serialize_to_fileurl(self): - with TemporaryDirectory() as td: - tfpath = Path(td) / "out.nt" - tfurl = tfpath.as_uri() - self.assertRegex(tfurl, r"^file:") - self.assertFalse(tfpath.exists()) - self.graph.serialize(destination=tfurl, format="nt") - self.assertTrue(tfpath.exists()) - graph_check = Graph() - graph_check.parse(source=tfpath, format="nt") - self.assertEqual(self.triple, next(iter(graph_check))) + def test_serialize_badformat(self) -> None: + with self.assertRaises(PluginException) as raised: + self.graph.serialize(destination="http://example.com/", format="badformat") + self.assertIn("badformat", f"{raised.exception}") + + def test_str(self) -> None: + """ + This function tests serialization of graphs to strings, either directly + or from query results. + + This function also checks that the various string serialization + overloads are correct. 
+ """ + for format in format_infos.keys(): + format_info = format_infos[format] + + def check(data: str, check_context: bool = True) -> None: + with self.subTest(format=format, caller=inspect.stack()[1]): + self.assertIsInstance(data, str) + + graph_check = FormatInfos.make_graph(format_info) + graph_check.parse(data=data, format=format_info.deserializer_name) + assert_graphs_equal(self, self.graph, graph_check, check_context) + + if format == "turtle": + check(self.graph.serialize()) + check(self.graph.serialize(None)) + check(self.graph.serialize(None, format)) + check(self.graph.serialize(None, format, encoding=None)) + check(self.graph.serialize(None, format, None, None)) + check(self.graph.serialize(None, format=format)) + check(self.graph.serialize(None, format=format, encoding=None)) + + if GraphType.TRIPLE not in format_info.graph_types: + # tests below are only for formats that can work with context-less graphs. + continue + + if format == "turtle": + check(self.result.serialize(), False) + check(self.result.serialize(None), False) + check(self.result.serialize(None, format=format), False) + check(self.result.serialize(None, None, format), False) + check(self.result.serialize(None, None, format=format), False) + check(self.result.serialize(None, encodin=None, format=format), False) + check( + self.result.serialize(destination=None, encoding=None, format=format), + False, + ) + + def test_bytes(self) -> None: + """ + This function tests serialization of graphs to bytes, either directly or + from query results. + + This function also checks that the various bytes serialization overloads + are correct. 
+ """ + for (format, encoding) in itertools.chain( + *( + itertools.product({format_info.serializer_name}, format_info.encodings) + for format_info in format_infos.values() + ) + ): + format_info = format_infos[format] + + def check(data: bytes, check_context: bool = True) -> None: + with self.subTest( + format=format, encoding=encoding, caller=inspect.stack()[1] + ): + # self.check_data_bytes(data, format=format, encoding=encoding) + self.assertIsInstance(data, bytes) + + # double check that encoding is right + data_str = data.decode(encoding) + graph_check = FormatInfos.make_graph(format_info) + graph_check.parse( + data=data_str, format=format_info.deserializer_name + ) + assert_graphs_equal(self, self.graph, graph_check, check_context) + + # actual check + # TODO FIXME : handle other encodings + if encoding == "utf-8": + graph_check = FormatInfos.make_graph(format_info) + graph_check.parse( + data=data, format=format_info.deserializer_name + ) + assert_graphs_equal( + self, self.graph, graph_check, check_context + ) + + if format == "turtle": + check(self.graph.serialize(encoding=encoding)) + check(self.graph.serialize(None, format, encoding=encoding)) + check(self.graph.serialize(None, format, None, encoding=encoding)) + check(self.graph.serialize(None, format, encoding=encoding)) + check(self.graph.serialize(None, format=format, encoding=encoding)) + + if GraphType.TRIPLE not in format_info.graph_types: + # tests below are only for formats that can work with context-less graphs. 
+ continue + + if format == "turtle": + check(self.result.serialize(encoding=encoding), False) + check(self.result.serialize(None, encoding), False) + check(self.result.serialize(encoding=encoding, format=format), False) + check(self.result.serialize(None, encoding, format), False) + check(self.result.serialize(None, encoding=encoding, format=format), False) + check( + self.result.serialize( + destination=None, encoding=encoding, format=format + ), + False, + ) + + def test_file(self) -> None: + """ + This function tests serialization of graphs to destinations, either directly or + from query results. + + This function also checks that the various bytes serialization overloads + are correct. + """ + dest_factory = DestinationFactory(self.tmpdir) + + for (format, encoding, dest_type) in itertools.chain( + *( + itertools.product( + {format_info.serializer_name}, + format_info.encodings, + set(DestinationType).difference({DestinationType.TEXT_IO}), + ) + for format_info in format_infos.values() + ) + ): + format_info = format_infos[format] + with ExitStack() as stack: + dest_path: Path + _dest: Union[str, Path, PurePath, IO[bytes]] + + def dest() -> Union[str, Path, PurePath, IO[bytes]]: + nonlocal dest_path + nonlocal _dest + _dest, dest_path = cast( + Tuple[Union[str, Path, PurePath, IO[bytes]], Path], + dest_factory.make(dest_type, stack), + ) + return _dest + + def _check(check_context: bool = True) -> None: + with self.subTest( + format=format, + encoding=encoding, + dest_type=dest_type, + caller=inspect.stack()[2], + ): + if isinstance(_dest, IOBase): # type: ignore[unreachable] + _dest.flush() + + source = Path(dest_path) + + # double check that encoding is right + data_str = source.read_text(encoding=encoding) + graph_check = FormatInfos.make_graph(format_info) + graph_check.parse( + data=data_str, format=format_info.deserializer_name + ) + assert_graphs_equal( + self, self.graph, graph_check, check_context + ) + + self.assertTrue(source.exists()) + # actual 
check + # TODO FIXME : This should work for all encodings, not just utf-8 + if encoding == "utf-8": + graph_check = FormatInfos.make_graph(format_info) + graph_check.parse( + source=source, format=format_info.deserializer_name + ) + assert_graphs_equal( + self, self.graph, graph_check, check_context + ) + + dest_path.unlink() + + def check_a(graph: Graph) -> None: + _check() + + if (format, encoding) == ("turtle", "utf-8"): + check_a(self.graph.serialize(dest())) + check_a(self.graph.serialize(dest(), encoding=None)) + if format == "turtle": + check_a(self.graph.serialize(dest(), encoding=encoding)) + if encoding == sys.getdefaultencoding(): + check_a(self.graph.serialize(dest(), format)) + check_a(self.graph.serialize(dest(), format, None)) + check_a(self.graph.serialize(dest(), format, None, None)) + + check_a(self.graph.serialize(dest(), format, encoding=encoding)) + check_a(self.graph.serialize(dest(), format, None, encoding)) + + if GraphType.TRIPLE not in format_info.graph_types: + # tests below are only for formats that can work with context-less graphs. 
+ continue + + def check_b(none: None) -> None: + _check(False) + + if format == "turtle": + check_b(self.result.serialize(dest(), encoding)) + check_b( + self.result.serialize( + destination=cast(str, dest()), + encoding=encoding, + ) + ) + check_b(self.result.serialize(dest(), encoding=encoding, format=format)) + check_b( + self.result.serialize( + destination=dest(), encoding=encoding, format=format + ) + ) + check_b( + self.result.serialize( + destination=dest(), encoding=None, format=format + ) + ) + check_b(self.result.serialize(destination=dest(), format=format)) if __name__ == "__main__": diff --git a/test/test_sparql.py b/test/test_sparql.py index d6e541e163..0af55a2feb 100644 --- a/test/test_sparql.py +++ b/test/test_sparql.py @@ -240,8 +240,15 @@ def test_txtresult(): assert result.type == "SELECT" assert len(result) == 1 assert result.vars == vars - txtresult = result.serialize(format="txt") - lines = txtresult.decode().splitlines() + + bytesresult = result.serialize(format="txt", encoding="utf-8") + lines = bytesresult.decode().splitlines() + assert len(lines) == 3 + vars_check = [Variable(var.strip()) for var in lines[0].split("|")] + assert vars_check == vars + + strresult = result.serialize(format="txt") + lines = strresult.splitlines() assert len(lines) == 3 vars_check = [Variable(var.strip()) for var in lines[0].split("|")] assert vars_check == vars diff --git a/test/test_sparql_result_serialize.py b/test/test_sparql_result_serialize.py new file mode 100644 index 0000000000..0f43c78cdf --- /dev/null +++ b/test/test_sparql_result_serialize.py @@ -0,0 +1,281 @@ +from contextlib import ExitStack +import itertools +from typing import ( + IO, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Set, + TextIO, + Union, + cast, +) +from rdflib.query import Result, ResultRow +from .test_serialize import DestinationFactory, DestinationType +from test.testutils import GraphHelper +from rdflib.term import Node +import unittest +from rdflib import Graph, 
Namespace +from tempfile import TemporaryDirectory +from pathlib import Path, PurePath +from io import BytesIO, IOBase, StringIO +import inspect + +EG = Namespace("http://example.com/") + + +class FormatInfo(NamedTuple): + serializer_name: str + deserializer_name: str + encodings: Set[str] + + +class FormatInfos(Dict[str, FormatInfo]): + def add_format( + self, + serializer_name: str, + deserializer_name: str, + *, + encodings: Set[str], + ) -> None: + self[serializer_name] = FormatInfo( + serializer_name, + deserializer_name, + encodings, + ) + + def select( + self, + *, + name: Optional[Set[str]] = None, + ) -> Iterable[FormatInfo]: + for format in self.values(): + if name is not None and format.serializer_name in name: + yield format + + @classmethod + def make(cls) -> "FormatInfos": + result = cls() + result.add_format("csv", "csv", encodings={"utf-8"}) + result.add_format("json", "json", encodings={"utf-8"}) + result.add_format("xml", "xml", encodings={"utf-8"}) + result.add_format("txt", "txt", encodings={"utf-8"}) + + return result + + +format_infos = FormatInfos.make() + + +class ResultHelper: + @classmethod + def to_list(cls, result: Result) -> List[Dict[str, Node]]: + output: List[Dict[str, Node]] = [] + row: ResultRow + for row in result: + output.append(row.asdict()) + return output + + +def check_txt(test_case: unittest.TestCase, result: Result, data: str) -> None: + """ + This does somewhat of a smoke test that data is the txt serialization of the + given result. This is by no means perfect but better than nothing. 
+ """ + txt_lines = data.splitlines() + test_case.assertEqual(len(txt_lines) - 2, len(result)) + test_case.assertRegex(txt_lines[1], r"^[-]+$") + header = txt_lines[0] + test_case.assertIsNotNone(result.vars) + assert result.vars is not None + for var in result.vars: + test_case.assertIn(var, header) + for row_index, row in enumerate(result): + txt_row = txt_lines[row_index + 2] + value: Node + for key, value in row.asdict().items(): + test_case.assertIn(f"{value}", txt_row) + + +class TestSerializeSelect(unittest.TestCase): + def setUp(self) -> None: + graph = Graph() + triples = [ + (EG["e0"], EG["a0"], EG["e1"]), + (EG["e0"], EG["a0"], EG["e2"]), + (EG["e0"], EG["a0"], EG["e3"]), + (EG["e1"], EG["a1"], EG["e2"]), + (EG["e1"], EG["a1"], EG["e3"]), + (EG["e2"], EG["a2"], EG["e3"]), + ] + GraphHelper.add_triples(graph, triples) + + query = """ + PREFIX eg: + SELECT ?subject ?predicate ?object WHERE { + VALUES ?predicate { eg:a1 } + ?subject ?predicate ?object + } ORDER BY ?object + """ + self.result = graph.query(query) + self.result_table = [ + ["subject", "predicate", "object"], + ["http://example.com/e1", "http://example.com/a1", "http://example.com/e2"], + ["http://example.com/e1", "http://example.com/a1", "http://example.com/e3"], + ] + + self._tmpdir = TemporaryDirectory() + self.tmpdir = Path(self._tmpdir.name) + + return super().setUp() + + def tearDown(self) -> None: + self._tmpdir.cleanup() + + def test_str(self) -> None: + for format in format_infos.keys(): + + def check(data: str) -> None: + with self.subTest(format=format, caller=inspect.stack()[1]): + self.assertIsInstance(data, str) + format_info = format_infos[format] + if format_info.deserializer_name == "txt": + check_txt(self, self.result, data) + else: + result_check = Result.parse( + StringIO(data), format=format_info.deserializer_name + ) + self.assertEqual(self.result, result_check) + + if format == "txt": + check(self.result.serialize()) + check(self.result.serialize(None, None, None)) + 
check(self.result.serialize(None, None, format)) + check(self.result.serialize(format=format)) + check(self.result.serialize(destination=None, format=format)) + check(self.result.serialize(destination=None, encoding=None, format=format)) + + def test_bytes(self) -> None: + for (format, encoding) in itertools.chain( + *( + itertools.product({format_info.serializer_name}, format_info.encodings) + for format_info in format_infos.values() + ) + ): + + def check(data: bytes) -> None: + with self.subTest(format=format, caller=inspect.stack()[1]): + self.assertIsInstance(data, bytes) + format_info = format_infos[format] + if format_info.deserializer_name == "txt": + check_txt(self, self.result, data.decode(encoding)) + else: + result_check = Result.parse( + BytesIO(data), format=format_info.deserializer_name + ) + self.assertEqual(self.result, result_check) + + if format == "txt": + check(self.result.serialize(encoding=encoding)) + check(self.result.serialize(None, encoding, None)) + check(self.result.serialize(None, encoding)) + check(self.result.serialize(None, encoding, format)) + check(self.result.serialize(format=format, encoding=encoding)) + check( + self.result.serialize( + destination=None, format=format, encoding=encoding + ) + ) + check( + self.result.serialize( + destination=None, encoding=encoding, format=format + ) + ) + + def test_file(self) -> None: + + dest_factory = DestinationFactory(self.tmpdir) + + for (format, encoding, dest_type) in itertools.chain( + *( + itertools.product( + {format_info.serializer_name}, + format_info.encodings, + set(DestinationType), + ) + for format_info in format_infos.values() + ) + ): + with ExitStack() as stack: + dest_path: Path + _dest: Union[str, Path, PurePath, IO[bytes], TextIO] + + def dest() -> Union[str, Path, PurePath, IO[bytes], TextIO]: + nonlocal dest_path + nonlocal _dest + _dest, dest_path = dest_factory.make(dest_type, stack) + return _dest + + def check(none: None) -> None: + with self.subTest( + 
format=format, + encoding=encoding, + dest_type=dest_type, + caller=inspect.stack()[1], + ): + if isinstance(_dest, IOBase): # type: ignore[unreachable] + _dest.flush() + format_info = format_infos[format] + data_str = dest_path.read_text(encoding=encoding) + if format_info.deserializer_name == "txt": + check_txt(self, self.result, data_str) + else: + result_check = Result.parse( + StringIO(data_str), format=format_info.deserializer_name + ) + self.assertEqual(self.result, result_check) + dest_path.unlink() + + if dest_type == DestinationType.IO_BYTES: + check( + self.result.serialize( + cast(IO[bytes], dest()), + encoding, + format, + ) + ) + check( + self.result.serialize( + cast(IO[bytes], dest()), + encoding, + format=format, + ) + ) + check( + self.result.serialize( + cast(IO[bytes], dest()), + encoding=encoding, + format=format, + ) + ) + check( + self.result.serialize( + destination=cast(IO[bytes], dest()), + encoding=encoding, + format=format, + ) + ) + check( + self.result.serialize( + destination=dest(), encoding=None, format=format + ) + ) + check(self.result.serialize(destination=dest(), format=format)) + check(self.result.serialize(dest(), format=format)) + check(self.result.serialize(dest(), None, format)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_trig.py b/test/test_trig.py index 9dcd0ecca7..1afc4c8188 100644 --- a/test/test_trig.py +++ b/test/test_trig.py @@ -1,6 +1,9 @@ import unittest +from unittest.case import expectedFailure import rdflib import re +from rdflib import Namespace +from .testutils import GraphHelper from nose import SkipTest @@ -11,12 +14,59 @@ ) +EG = Namespace("http://example.com/") + + class TestTrig(unittest.TestCase): def testEmpty(self): g = rdflib.Graph() s = g.serialize(format="trig") self.assertTrue(s is not None) + def test_single_quad(self) -> None: + graph = rdflib.ConjunctiveGraph() + quad = (EG["subject"], EG["predicate"], EG["object"], EG["graph"]) + graph.add(quad) + check_graph = 
rdflib.ConjunctiveGraph() + data_str = graph.serialize(format="trig") + check_graph.parse(data=data_str, format="trig") + quad_set, check_quad_set = GraphHelper.quad_sets([graph, check_graph]) + self.assertEqual(quad_set, check_quad_set) + + @expectedFailure + def test_default_identifier(self) -> None: + """ + This should pass, but for some reason when the default identifier is + set, trig serializes quads inside this default identifier to an + anonymous graph. + + So in this test, data_str is: + + @base . + @prefix ns1: . + + { + ns1:subject ns1:predicate ns1:object . + } + + instead of: + @base . + @prefix ns1: . + + ns1:graph { + ns1:subject ns1:predicate ns1:object . + } + """ + graph_id = EG["graph"] + graph = rdflib.ConjunctiveGraph(identifier=EG["graph"]) + quad = (EG["subject"], EG["predicate"], EG["object"], graph_id) + graph.add(quad) + check_graph = rdflib.ConjunctiveGraph() + data_str = graph.serialize(format="trig") + check_graph.parse(data=data_str, format="trig") + quad_set, check_quad_set = GraphHelper.quad_sets([graph, check_graph]) + self.assertEqual(quad_set, check_quad_set) + def testRepeatTriples(self): g = rdflib.ConjunctiveGraph() g.get_context("urn:a").add( diff --git a/test/test_util.py b/test/test_util.py index 76b6c51a02..aa9663adf5 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +from io import BufferedIOBase, RawIOBase, TextIOBase +from typing import BinaryIO, TextIO import unittest import time from unittest.case import expectedFailure @@ -15,6 +17,9 @@ from rdflib.exceptions import PredicateTypeError from rdflib.exceptions import ObjectTypeError from rdflib.exceptions import ContextTypeError +from pathlib import Path +from tempfile import TemporaryDirectory +from rdflib.util import as_textio n3source = """\ @prefix : . @@ -44,7 +49,7 @@ :sister a rdf:Property. -:sister rdfs:domain :Person; +:sister rdfs:domain :Person; rdfs:range :Woman. :Woman = foo:FemaleAdult . 
@@ -391,5 +396,48 @@ def test_util_check_pattern(self): self.assertTrue(res == None) +class TestIO(unittest.TestCase): + def setUp(self) -> None: + self._tmpdir = TemporaryDirectory() + self.tmpdir = Path(self._tmpdir.name) + + return super().setUp() + + def tearDown(self) -> None: + self._tmpdir.cleanup() + + def test_as_textio_text(self) -> None: + tmp_file = self.tmpdir / "file" + with tmp_file.open("w") as text_stream: + text_io: TextIO = text_stream + assert text_io is text_stream + with as_textio(text_stream) as text_io: + assert text_io is text_stream + text_io.write("Test") + text_stream.flush() + self.assertEqual(tmp_file.read_text(), "Test") + self.assertIsInstance(text_io, TextIOBase) + + def test_as_textio_buffered_stream(self) -> None: + tmp_file = self.tmpdir / "file" + with tmp_file.open("wb") as buffered_stream: + binary_io: BinaryIO = buffered_stream + assert binary_io is buffered_stream + with as_textio(buffered_stream) as text_io: + text_io.write("Test") + self.assertEqual(tmp_file.read_text(), "Test") + self.assertIsInstance(buffered_stream, BufferedIOBase) + + def test_as_textio_raw_stream(self) -> None: + tmp_file = self.tmpdir / "file" + with tmp_file.open("wb", buffering=0) as raw_stream: + binary_io: BinaryIO = raw_stream + assert binary_io is raw_stream + with as_textio(raw_stream) as text_io: + text_io.write("Test") + self.assertEqual(tmp_file.read_text(), "Test") + self.assertIsInstance(binary_io, RawIOBase) + + if __name__ == "__main__": unittest.main() diff --git a/test/testutils.py b/test/testutils.py index 05ddf90718..5c81d8773e 100644 --- a/test/testutils.py +++ b/test/testutils.py @@ -1,4 +1,6 @@ from __future__ import print_function +from rdflib.graph import Dataset +from rdflib.plugin import Plugin import os import sys @@ -9,6 +11,7 @@ from contextlib import AbstractContextManager, contextmanager from typing import ( + Generic, Iterable, List, Optional, @@ -33,10 +36,11 @@ import unittest from rdflib import BNode, Graph, 
ConjunctiveGraph -from rdflib.term import Node +from rdflib.term import Node, URIRef from unittest.mock import MagicMock, Mock from urllib.error import HTTPError from urllib.request import urlopen +import rdflib.plugin if TYPE_CHECKING: import typing_extensions as te @@ -47,6 +51,32 @@ from test import TEST_DIR from test.earl import add_test, report +PluginT = TypeVar("PluginT") + + +class PluginWithNames(NamedTuple, Generic[PluginT]): + plugin: Plugin[PluginT] + names: Set[str] + + +def get_unique_plugins( + type: Type[PluginT], +) -> Dict[Type[PluginT], Set[Plugin[PluginT]]]: + result: Dict[Type[PluginT], Set[Plugin[PluginT]]] = {} + for plugin in rdflib.plugin.plugins(None, type): + cls = plugin.getClass() + plugins = result.setdefault(cls, set()) + plugins.add(plugin) + return result + + +def get_unique_plugin_names(type: Type[PluginT]) -> Set[str]: + result: Set[str] = set() + unique_plugins = get_unique_plugins(type) + for type, plugin_set in unique_plugins.items(): + result.add(next(iter(plugin_set)).name) + return result + def crapCompare(g1, g2): """A really crappy way to 'check' if two graphs are equal. 
It ignores blank @@ -173,6 +203,14 @@ def ctx_http_server(handler: Type[BaseHTTPRequestHandler]) -> Iterator[HTTPServe class GraphHelper: + @classmethod + def add_triples( + cls, graph: Graph, triples: Iterable[Tuple[Node, Node, Node]] + ) -> Graph: + for triple in triples: + graph.add(triple) + return graph + @classmethod def triple_set(cls, graph: Graph) -> Set[Tuple[Node, Node, Node]]: return set(graph.triples((None, None, None))) @@ -185,8 +223,25 @@ def triple_sets(cls, graphs: Iterable[Graph]) -> List[Set[Tuple[Node, Node, Node return result @classmethod - def equals(cls, lhs: Graph, rhs: Graph) -> bool: - return cls.triple_set(lhs) == cls.triple_set(rhs) + def quad_set(cls, graph: ConjunctiveGraph) -> Set[Tuple[Node, Node, Node, Node]]: + result: Set[Tuple[Node, Node, Node, Node]] = set() + for quad in graph.quads((None, None, None, None)): + c: Graph + s, p, o, c = quad + if isinstance(graph, Dataset): + result.add((s, p, o, c)) + else: + result.add((s, p, o, c.identifier)) + return result + + @classmethod + def quad_sets( + cls, graphs: Iterable[ConjunctiveGraph] + ) -> List[Set[Tuple[Node, Node, Node, Node]]]: + result: List[Set[Tuple[Node, Node, Node, Node]]] = [] + for graph in graphs: + result.append(cls.quad_set(graph)) + return result GenericT = TypeVar("GenericT", bound=Any)