Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SnippetGenerator for highlighting #36

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ crate-type = ["cdylib"]

[dependencies]
chrono = "0.4.19"
tantivy = "0.13.2"
tantivy = "0.16.1"
itertools = "0.9.0"
futures = "0.3.5"

Expand All @@ -22,4 +22,4 @@ features = ["extension-module"]

[package.metadata.maturin]
requires-python = ">=3.7"
project-url = { Source = "https://github.com/quickwit-inc/tantivy-py" }
project-url = ["https://github.com/quickwit-inc/tantivy-py"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaics this change breaks building with maturin tooling

❯ maturin build
💥 maturin failed
  Caused by: Failed to parse Cargo.toml at Cargo.toml
  Caused by: invalid type: sequence, expected a map for key `package.metadata.maturin.project-url` at line 23 column 1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jgreene Can you confirm and fix?

2 changes: 1 addition & 1 deletion src/document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ impl Document {
}

impl Document {
fn iter_values_for_field<'a>(
pub(crate) fn iter_values_for_field<'a>(
&'a self,
field: &str,
) -> impl Iterator<Item = &'a Value> + 'a {
Expand Down
12 changes: 8 additions & 4 deletions src/facet.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use pyo3::{basic::PyObjectProtocol, prelude::*, types::PyType};
use tantivy::schema;
use crate::{
to_pyerr,
};

/// A Facet represents a point in a given hierarchy.
///
Expand Down Expand Up @@ -46,10 +49,11 @@ impl Facet {
///
/// Returns the created Facet.
#[classmethod]
fn from_string(_cls: &PyType, facet_string: &str) -> Facet {
Facet {
inner: schema::Facet::from_text(facet_string),
}
fn from_string(_cls: &PyType, facet_string: &str) -> PyResult<Facet> {
let inner = schema::Facet::from_text(facet_string).map_err(to_pyerr)?;
Ok(Facet {
inner: inner,
})
}

/// Returns the list of `segments` that forms a facet path.
Expand Down
5 changes: 3 additions & 2 deletions src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ impl Index {
if reuse {
tv::Index::open_or_create(directory, schema.inner.clone())
} else {
tv::Index::create(directory, schema.inner.clone())
tv::Index::create(directory, schema.inner.clone(), tv::IndexSettings::default())
}
.map_err(to_pyerr)?
}
Expand Down Expand Up @@ -277,7 +277,8 @@ impl Index {
#[staticmethod]
fn exists(path: &str) -> PyResult<bool> {
let directory = MmapDirectory::open(path).map_err(to_pyerr)?;
Ok(tv::Index::exists(&directory))
let exists = tv::Index::exists(&directory).map_err(to_pyerr)?;
Ok(exists)
}

/// The schema of the current index.
Expand Down
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ mod query;
mod schema;
mod schemabuilder;
mod searcher;
mod snippet;

use document::Document;
use facet::Facet;
use index::Index;
use query::Query;
use schema::Schema;
use schemabuilder::SchemaBuilder;
use searcher::{DocAddress, Searcher};
use snippet::{SnippetGenerator, Snippet};

/// Python bindings for the search engine library Tantivy.
///
Expand Down Expand Up @@ -75,6 +78,9 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Index>()?;
m.add_class::<DocAddress>()?;
m.add_class::<Facet>()?;
m.add_class::<Query>()?;
m.add_class::<Snippet>()?;
m.add_class::<SnippetGenerator>()?;
Ok(())
}

Expand Down
64 changes: 60 additions & 4 deletions src/schemabuilder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,17 @@ impl SchemaBuilder {
/// Add a Facet field to the schema.
/// Args:
/// name (str): The name of the field.
fn add_facet_field(&mut self, name: &str) -> PyResult<Self> {
#[args(stored = false, indexed = false)]
fn add_facet_field(&mut self,
name: &str,
stored: bool,
indexed: bool) -> PyResult<Self> {
let builder = &mut self.builder;

let opts = SchemaBuilder::build_facet_option(stored, indexed)?;

if let Some(builder) = builder.write().unwrap().as_mut() {
builder.add_facet_field(name);
builder.add_facet_field(name, opts);
} else {
return Err(exceptions::PyValueError::new_err(
"Schema builder object isn't valid anymore.",
Expand All @@ -253,11 +259,17 @@ impl SchemaBuilder {
///
/// Args:
/// name (str): The name of the field.
fn add_bytes_field(&mut self, name: &str) -> PyResult<Self> {
#[args(stored = false, indexed = false)]
fn add_bytes_field(&mut self,
name: &str,
stored: bool,
indexed: bool,
fast: Option<&str>) -> PyResult<Self> {
let builder = &mut self.builder;

let opts = SchemaBuilder::build_bytes_option(stored, indexed, fast)?;
if let Some(builder) = builder.write().unwrap().as_mut() {
builder.add_bytes_field(name);
builder.add_bytes_field(name, opts);
} else {
return Err(exceptions::PyValueError::new_err(
"Schema builder object isn't valid anymore.",
Expand All @@ -284,6 +296,50 @@ impl SchemaBuilder {
}

impl SchemaBuilder {
/// Translates the Python-level `stored` / `indexed` flags into Tantivy
/// `FacetOptions`.
///
/// Starts from `FacetOptions::default()` and switches on each capability
/// whose flag is `true`. The function never produces an error itself; it
/// returns `PyResult` so it has the same shape as the other option builders.
fn build_facet_option(
    stored: bool,
    indexed: bool,
) -> PyResult<schema::FacetOptions> {
    let mut facet_opts = schema::FacetOptions::default();

    if stored {
        facet_opts = facet_opts.set_stored();
    }
    if indexed {
        facet_opts = facet_opts.set_indexed();
    }

    Ok(facet_opts)
}

/// Translates the Python-level flags into Tantivy `BytesOptions`.
///
/// Args:
///     stored: Mark the field as stored when `true`.
///     indexed: Mark the field as indexed when `true`.
///     fast: Optional fast-field selector. `None` leaves the field non-fast.
///         `"single"` or `"multi"` (case-insensitive) both enable the fast
///         field; the two values are not distinguished because
///         `BytesOptions::set_fast` takes no cardinality argument.
///
/// Returns the configured options, or a `ValueError` if `fast` is any other
/// string.
fn build_bytes_option(
    stored: bool,
    indexed: bool,
    fast: Option<&str>,
) -> PyResult<schema::BytesOptions> {
    // Only the presence of a valid selector matters, so validate it and
    // reduce it to a flag instead of mapping it to a cardinality that would
    // never be used.
    let is_fast = match fast.map(str::to_lowercase).as_deref() {
        None => false,
        Some("single") | Some("multi") => true,
        Some(_) => {
            // The message lists the values that are actually accepted above.
            return Err(exceptions::PyValueError::new_err(
                "Invalid fast option, valid choices are: 'single' and 'multi'",
            ));
        }
    };

    let opts = schema::BytesOptions::default();
    let opts = if stored { opts.set_stored() } else { opts };
    let opts = if indexed { opts.set_indexed() } else { opts };
    let opts = if is_fast { opts.set_fast() } else { opts };

    Ok(opts)
}

fn build_int_option(
stored: bool,
indexed: bool,
Expand Down
12 changes: 6 additions & 6 deletions src/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ impl Searcher {
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct DocAddress {
pub(crate) segment_ord: tv::SegmentLocalId,
pub(crate) doc: tv::DocId,
pub(crate) segment_ord: tv::SegmentOrdinal,
pub(crate) doc_id: tv::DocId,
fulmicoton marked this conversation as resolved.
Show resolved Hide resolved
}

#[pymethods]
Expand All @@ -212,22 +212,22 @@ impl DocAddress {
/// The segment local DocId
#[getter]
fn doc(&self) -> u32 {
self.doc
self.doc_id
}
}

impl From<&tv::DocAddress> for DocAddress {
fn from(doc_address: &tv::DocAddress) -> Self {
DocAddress {
segment_ord: doc_address.segment_ord(),
doc: doc_address.doc(),
segment_ord: doc_address.segment_ord,
doc_id: doc_address.doc_id,
}
}
}

impl Into<tv::DocAddress> for &DocAddress {
fn into(self) -> tv::DocAddress {
tv::DocAddress(self.segment_ord(), self.doc())
tv::DocAddress { segment_ord: self.segment_ord, doc_id: self.doc_id }
}
}

Expand Down
69 changes: 69 additions & 0 deletions src/snippet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use pyo3::prelude::*;
use tantivy as tv;
use crate::{
to_pyerr,
};

/// Python-visible wrapper around a Tantivy `Snippet`.
///
/// Instances are returned by `SnippetGenerator::snippet_from_doc`. The
/// wrapped value is exposed to Python through the `to_html` and
/// `highlighted` methods.
#[pyclass]
pub(crate) struct Snippet {
// The underlying Tantivy snippet that every method delegates to.
pub(crate) inner: tv::Snippet,
}

/// A `start` / `end` pair describing one highlighted section of a snippet.
///
/// Both fields are read-only attributes on the Python side (`#[pyo3(get)]`).
// NOTE(review): the values are copied verbatim from the ranges returned by
// the wrapped snippet's `highlighted()`; presumably they are offsets into the
// snippet text with an exclusive `end` -- confirm against Tantivy's docs.
#[pyclass]
pub(crate) struct Range {
// Offset at which the highlighted section begins.
#[pyo3(get)]
start: usize,
// Offset at which the highlighted section ends.
#[pyo3(get)]
end: usize
}

#[pymethods]
impl Snippet {
    /// Renders the snippet as an HTML string by delegating to the wrapped
    /// Tantivy snippet.
    ///
    /// The call itself cannot fail; the value is always wrapped in `Ok`.
    pub fn to_html(&self) -> PyResult<String> {
        let html = self.inner.to_html();
        Ok(html)
    }

    /// Returns one `Range` per highlighted section of the snippet, in the
    /// order reported by the wrapped Tantivy snippet.
    pub fn highlighted(&self) -> Vec<Range> {
        self.inner
            .highlighted()
            .iter()
            .map(|section| Range {
                start: section.start,
                end: section.end,
            })
            .collect()
    }
}


/// Python-visible wrapper around a Tantivy `SnippetGenerator` bound to a
/// single field.
///
/// Built with `SnippetGenerator::create`; `snippet_from_doc` then produces a
/// `Snippet` from the values a document holds for that field.
#[pyclass]
pub(crate) struct SnippetGenerator {
// Name of the field whose values are read from each document.
pub(crate) field_name: String,
// The underlying Tantivy generator that computes the snippets.
pub(crate) inner: tv::SnippetGenerator,
}

#[pymethods]
impl SnippetGenerator {
    /// Creates a generator for `field_name` from a searcher and a query.
    ///
    /// Args:
    ///     searcher: The searcher handed to Tantivy's generator.
    ///     query: The query handed to Tantivy's generator.
    ///     schema: The schema used to resolve `field_name`.
    ///     field_name: Name of the field snippets will be generated for.
    ///
    /// Returns the new generator. An unknown field name, or a failure inside
    /// Tantivy's `SnippetGenerator::create`, is converted to a Python error
    /// through `to_pyerr`.
    #[staticmethod]
    pub fn create(
        searcher: &crate::Searcher,
        query: &crate::Query,
        schema: &crate::Schema,
        field_name: &str,
    ) -> PyResult<SnippetGenerator> {
        let field = schema
            .inner
            .get_field(field_name)
            .ok_or("field not found")
            .map_err(to_pyerr)?;
        let inner = tv::SnippetGenerator::create(&*searcher.inner, query.get(), field)
            .map_err(to_pyerr)?;

        Ok(SnippetGenerator {
            field_name: field_name.to_string(),
            inner,
        })
    }

    /// Builds a snippet from the text values `doc` holds for this
    /// generator's field.
    ///
    /// Values for which `text()` yields nothing are skipped; the remaining
    /// strings are joined with a single space before being passed to the
    /// wrapped generator.
    pub fn snippet_from_doc(&self, doc: &crate::Document) -> crate::Snippet {
        let fragments: Vec<&str> = doc
            .iter_values_for_field(&self.field_name)
            .filter_map(|value| value.text())
            .collect();
        let text = fragments.join(" ");

        Snippet {
            inner: self.inner.snippet(&text),
        }
    }
}
26 changes: 25 additions & 1 deletion tests/tantivy_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tantivy
import pytest

from tantivy import Document, Index, SchemaBuilder, Schema
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator


def schema():
Expand Down Expand Up @@ -322,3 +322,27 @@ def test_document_with_facet(self):
def test_document_error(self):
with pytest.raises(ValueError):
tantivy.Document(name={})


class TestSnippets(object):
    """Checks snippet generation and highlighting for a search hit."""

    def test_document_snippet(self, dir_index):
        """Search for "sea whale" and verify the highlighted title snippet."""
        index_path, _ = dir_index
        doc_schema = schema()
        index = Index(doc_schema, str(index_path))
        query = index.parse_query("sea whale", ["title", "body"])
        searcher = index.searcher()
        result = searcher.search(query)
        assert len(result.hits) == 1

        generator = SnippetGenerator.create(searcher, query, doc_schema, "title")

        for _score, address in result.hits:
            document = searcher.doc(address)
            snippet = generator.snippet_from_doc(document)

            ranges = snippet.highlighted()
            assert len(ranges) == 1

            # The single highlight is expected to cover "Sea" in the title.
            highlight = ranges[0]
            assert highlight.start == 20
            assert highlight.end == 23

            rendered = snippet.to_html()
            assert rendered == 'The Old Man and the <b>Sea</b>'