refactor(wiki): index flow done
niuzs-alan committed Jun 11, 2021
1 parent 66694ea commit 15a7caa
Showing 11 changed files with 49 additions and 258 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -144,4 +144,4 @@ nohup.out

# cache
enwiki-latest-abstract.xml
wiki_dump.gz
wiki_dump.gz
38 changes: 0 additions & 38 deletions distributed/wiki/annoy_indexer.yml

This file was deleted.

36 changes: 0 additions & 36 deletions distributed/wiki/chunk_indexer.yml

This file was deleted.

10 changes: 0 additions & 10 deletions distributed/wiki/chunk_merger.yml

This file was deleted.

32 changes: 0 additions & 32 deletions distributed/wiki/doc.yml

This file was deleted.

15 changes: 0 additions & 15 deletions distributed/wiki/encoder.yml

This file was deleted.

79 changes: 35 additions & 44 deletions distributed/wiki/executors.py
@@ -1,9 +1,15 @@
import os
import json
import errno
import torch
import numpy as np
from pathlib import Path
from annoy import AnnoyIndex
from typing import Dict, Optional

import torch
from transformers import AutoModel, AutoTokenizer
from jina import Executor, DocumentArray, requests, Document
from jina.types.arrays.memmap import DocumentArrayMemmap


class Segmenter(Executor):
@@ -82,7 +88,13 @@ def _compute_embedding(self, hidden_states: 'torch.Tensor', input_tokens: Dict):
@requests
def encode(self, docs: 'DocumentArray', *args, **kwargs):

texts = docs.get_attributes('text')
chunks = DocumentArray(
list(
filter(lambda d: d.mime_type == 'text/plain', docs.traverse_flat(['c']))
)
)

texts = chunks.get_attributes('text')

with torch.no_grad():

@@ -107,9 +119,11 @@ def encode(self, docs: 'DocumentArray', *args, **kwargs):
hidden_states = outputs.hidden_states

embeds = self._compute_embedding(hidden_states, input_tokens)
for doc, embed in zip(docs, embeds):
for doc, embed in zip(chunks, embeds):
doc.embedding = embed

return chunks
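
Note that with this change TextEncoder no longer embeds root documents: it flattens each document's chunks (traversal path 'c'), keeps only the text/plain ones, encodes those, and returns them so downstream pods receive the chunks directly. A minimal sketch of the filtering step, assuming the 2021-era jina API used in this repo (Document, DocumentArray, traverse_flat); the sample texts are illustrative:

from jina import Document, DocumentArray

doc = Document(text='full wiki abstract')
text_chunk = Document(text='first sentence')   # setting text marks the chunk as text/plain
image_chunk = Document()
image_chunk.mime_type = 'image/png'            # non-text chunk, should be skipped
doc.chunks.extend([text_chunk, image_chunk])

docs = DocumentArray([doc])
# the same expression as in encode() above: flatten one level of chunks,
# then keep only the plain-text ones for the transformer encoder
chunks = DocumentArray(
    list(filter(lambda d: d.mime_type == 'text/plain', docs.traverse_flat(['c'])))
)
assert len(chunks) == 1 and chunks[0].text == 'first sentence'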


class AnnoyIndexer(Executor):

@@ -121,7 +135,7 @@ class AnnoyIndexer(Executor):
def __init__(
self,
top_k: int = 10,
num_dim: int = 512,
num_dim: int = 768,
num_trees: int = 10,
metric: str = 'angular',
**kwargs,
@@ -187,47 +201,24 @@ def close(self):
json.dump(self.id_docid_map, f)
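
Two details worth noting in AnnoyIndexer: the default num_dim moves from 512 to 768, presumably to match the 768-dimensional hidden states of the BERT-base-style model behind TextEncoder, and close() persists the id-to-document-id map as JSON because Annoy itself only stores integer item ids. A rough sketch of the build/query cycle the executor wraps, using annoy's public API directly; the file names and vectors are illustrative:

import json
import numpy as np
from annoy import AnnoyIndex

num_dim, num_trees, top_k = 768, 10, 10
index = AnnoyIndex(num_dim, 'angular')
id_docid_map = {}

embeddings = {'doc-0': np.random.rand(num_dim), 'doc-1': np.random.rand(num_dim)}
for i, (doc_id, vec) in enumerate(embeddings.items()):
    index.add_item(i, vec)       # Annoy items are keyed by consecutive ints
    id_docid_map[i] = doc_id     # so keep a side map back to document ids

index.build(num_trees)
index.save('annoy.index')
with open('id_map.json', 'w') as f:
    json.dump(id_docid_map, f)

# query side: top-k nearest neighbours by angular distance
ids, dists = index.get_nns_by_vector(np.random.rand(num_dim), top_k, include_distances=True)
matched_doc_ids = [id_docid_map[i] for i in ids]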


class DocIndexer(Executor):
@requests
def index(self, docs: DocumentArray, **kwargs):
pass


import numpy as np
from jina.helper import deprecated_alias


class SimpleAggregateRanker(Executor):
"""
:class:`SimpleAggregateRanker` aggregates the score
of the matched doc from the matched chunks.
For each matched doc, the score is aggregated
from all the matched chunks belonging to that doc.
:param: aggregate_function: defines the used aggregate function
and can be one of: [min, max, mean, median, sum, prod]
:param: inverse_score: plus-one inverse by 1/(1+score)
:raises:
ValueError: If `aggregate_function` is not any of the expected types
:param args: Additional positional arguments
:param kwargs: Additional keyword arguments
"""
class KeyValueIndexer(Executor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._docs = DocumentArrayMemmap(self.workspace + '/kv-idx')

AGGREGATE_FUNCTIONS = ['min', 'max', 'mean', 'median', 'sum', 'prod']
@property
def save_path(self):
if not os.path.exists(self.workspace):
os.makedirs(self.workspace)
return os.path.join(self.workspace, 'kv.json')

def __init__(
self, aggregate_function: str, inverse_score: bool = False, *args, **kwargs
):
"""Set constructor"""
super().__init__(*args, **kwargs)
self.inverse_score = inverse_score
if aggregate_function in self.AGGREGATE_FUNCTIONS:
self.np_aggregate_function = getattr(np, aggregate_function)
else:
raise ValueError(
f'The aggregate function "{aggregate_function}" is not in "{self.AGGREGATE_FUNCTIONS}".'
)
@requests(on='/index')
def index(self, docs: DocumentArray, **kwargs):
self._docs.extend(docs)

@requests(on='/search')
def score(self, docs, **kwargs):
pass
def query(self, docs: DocumentArray, **kwargs):
for doc in docs:
for match in doc.matches:
extracted_doc = self._docs[match.parent_id]
match.update(extracted_doc)
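
KeyValueIndexer replaces both the DocIndexer stub and SimpleAggregateRanker removed above: at index time it appends full documents to a DocumentArrayMemmap under the executor workspace, and at search time it resolves each match's parent_id (the matched chunk's root document) back to the stored original and copies its fields onto the match. A hedged usage sketch, assuming the jina 2.0-era Executor API; the workspace path and ids are illustrative:

from jina import Document, DocumentArray

indexer = KeyValueIndexer(metas={'workspace': './workspace'})

# index time: persist the root documents
indexer.index(DocumentArray([Document(id='doc-0', text='full wiki abstract')]))

# search time: a match carries only the chunk's parent_id;
# update() enriches it with the stored document's fields
query = Document(text='some query')
query.matches.append(Document(parent_id='doc-0'))
indexer.query(DocumentArray([query]))
print(query.matches[0].text)   # expected: 'full wiki abstract'
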
25 changes: 13 additions & 12 deletions distributed/wiki/local/index.yml
@@ -1,21 +1,22 @@
jtype: Flow
version: '1'
with:
workspace: $HW_WORKDIR
workspace: $JINA_WORKDIR
py_modules:
- executors.py
pods:
- name: segmenter
# shards: {{ JINA_SEGMENTER_SHARDS }}
# scheduling: {{ JINA_SCHEDULING }}
read_only: true
timeout_ready: 100000
uses:
jtype: Segmenter
- name: encoder
# scheduling: {{ JINA_SCHEDULING }}
# shards: {{ JINA_ENCODER_SHARDS }}
timeout_ready: 100000
read_only: true
uses:
jtype: TextEncoder
- name: vec_idx
# scheduling: {{ JINA_SCHEDULING }}
# shards: {{ JINA_VEC_INDEXER_SHARDS }}
timeout_ready: 100000
uses:
jtype: AnnoyIndexer
- name: doc_idx
uses:
jtype: KeyValueIndexer
needs: segmenter
- name: join_all
needs: [vec_idx, doc_idx]
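
With the per-executor YAMLs deleted, the Flow now loads everything from executors.py via py_modules and keeps a two-branch topology: segmenter → encoder → vec_idx for embeddings, segmenter → doc_idx for the key-value store, joined again in join_all. A sketch of driving the index Flow, assuming the jina 2.0-era client API; JINA_WORKDIR's value and the sample document are illustrative:

import os
from jina import Document, Flow

os.environ['JINA_WORKDIR'] = './workspace'

f = Flow.load_config('local/index.yml')
with f:
    # '/index' is the endpoint KeyValueIndexer.index is bound to
    f.post(on='/index', inputs=[Document(text='Anarchism is a political philosophy ...')])
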
5 changes: 0 additions & 5 deletions distributed/wiki/segment.yml

This file was deleted.

38 changes: 0 additions & 38 deletions distributed/wiki/segmenters.py

This file was deleted.

27 changes: 0 additions & 27 deletions distributed/wiki/test.py

This file was deleted.
