Skip to content

Commit

Permalink
fix: edspdf 0.7.0 regression vs article
Browse files: browse the repository at this point in the history
  • Loading branch information
percevalw committed Jul 26, 2023
1 parent 78f86e7 commit dd51e50
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 20 deletions.
14 changes: 10 additions & 4 deletions edspdf/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,6 @@ def to(self, device: Optional[torch.device] = None):
component.to(device)
return self

@contextmanager
def train(self, mode=True):
"""
Enables training mode on pytorch modules
Expand All @@ -621,12 +620,19 @@ def train(self, mode=True):
Whether to enable training or not
"""

class context:
def __enter__(self):
pass

def __exit__(ctx_self, type, value, traceback):
for name, proc in self.trainable_pipes():
proc.train(was_training[name])

was_training = {name: proc.training for name, proc in self.trainable_pipes()}
for name, proc in self.trainable_pipes():
proc.train(mode)
yield
for name, proc in self.trainable_pipes():
proc.train(was_training[name])

return context()

def score(self, docs: Sequence[PDFDoc], batch_size: int = None) -> Dict[str, Any]:
"""
Expand Down
7 changes: 2 additions & 5 deletions edspdf/pipes/embeddings/box_layout_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
BoxLayoutBatch = TypedDict(
"BoxLayoutBatch",
{
"page": FoldedTensor,
"xmin": FoldedTensor,
"ymin": FoldedTensor,
"xmax": FoldedTensor,
Expand Down Expand Up @@ -62,10 +61,9 @@ def __init__(

def preprocess(self, doc: PDFDoc, supervision: bool = False):
pages = doc.pages
box_pages = [[b.page_num for b in page.text_boxes] for page in pages]
last_p = max((p for x in box_pages for p in x), default=0)
[[b.page_num for b in page.text_boxes] for page in pages]
last_p = doc.num_pages - 1
return {
"page": box_pages,
"xmin": [[b.x0 for b in p.text_boxes] for p in pages],
"ymin": [[b.y0 for b in p.text_boxes] for p in pages],
"xmax": [[b.x1 for b in p.text_boxes] for p in pages],
Expand All @@ -84,7 +82,6 @@ def collate(self, batch, device: torch.device) -> BoxLayoutBatch:
}

return {
"page": as_folded_tensor(batch["page"], dtype=torch.float, **kw),
"xmin": as_folded_tensor(batch["xmin"], dtype=torch.float, **kw),
"ymin": as_folded_tensor(batch["ymin"], dtype=torch.float, **kw),
"xmax": as_folded_tensor(batch["xmax"], dtype=torch.float, **kw),
Expand Down
5 changes: 3 additions & 2 deletions edspdf/pipes/embeddings/simple_text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ def preprocess(self, doc: PDFDoc):
words = [m.group(0) for m in self.word_regex.finditer(b.text)]

for word in words:
ascii_str = anyascii(word)
# ascii_str = unidecode.unidecode(word)
ascii_str = anyascii(word).strip()
tokens_shape[-1][i].append(
self.shape_voc.encode(word_shape(ascii_str))
)
Expand Down Expand Up @@ -253,7 +254,7 @@ def forward(self, batch: BoxTextEmbeddingInputBatch) -> EmbeddingOutput:
self.shape_embedding(batch["tokens_shape"].as_tensor())
+ self.prefix_embedding(batch["tokens_prefix"].as_tensor())
+ self.suffix_embedding(batch["tokens_suffix"].as_tensor())
+ self.norm_embedding(batch["tokens_norm"].as_tensor())
# + self.norm_embedding(batch["tokens_norm"].as_tensor())
)

return {"embeddings": batch["tokens_shape"].with_data(text_embeds)}
4 changes: 3 additions & 1 deletion edspdf/pipes/embeddings/sub_box_cnn_pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,12 @@ def forward(self, batch: Any) -> EmbeddingOutput:
dim=2,
)
pooled = box_token_embeddings.max(1).values
pooled = self.linear(pooled)
# print("TEXT EMBEDS", pooled.shape, pooled.sum())

return {
"embeddings": as_folded_tensor(
data=self.linear(pooled),
data=pooled,
lengths=embeddings.lengths[:-1], # pooled on the last dim
data_dims=["line"], # fully flattened
full_names=["sample", "page", "line"],
Expand Down
1 change: 1 addition & 0 deletions edspdf/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class PDFDoc(BaseModel):

content: bytes = attrs.field(repr=lambda c: f"{len(c)} bytes")
id: str = None
num_pages: int = 0
pages: List["Page"] = attrs.field(factory=list)
error: bool = False
content_boxes: List[Union["TextBox"]] = attrs.field(factory=list)
Expand Down
24 changes: 17 additions & 7 deletions edspdf/utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,25 @@ def compute_pdf_relative_positions(x0, y0, x1, y1, width, height, n_relative_pos
torch.LongTensor
Shape: n_pages * n_boxes * n_boxes * 2
"""
dx = x0[:, None, :] - x0[:, :, None] # B begin -> A begin
dx = (dx * n_relative_positions).long()
dx0 = x1[:, None, :] - x0[:, :, None] # B end -> A begin
dx1 = x0[:, None, :] - x1[:, :, None] # B begin -> A end
dx = (
torch.where(
(dx0.sign() > 0) & (dx1.sign() > 0),
torch.minimum(dx0, dx1),
torch.where(
(dx0.sign() < 0) & (dx1.sign() < 0), torch.maximum(dx0, dx1), 0
),
)
/ 2
* n_relative_positions
).long()

dy = y0[:, None, :] - y0[:, :, None]
# If query above (dy > 0) key, use query height
ref_height = (dy >= 0).float() * height.float()[:, :, None] + (
dy < 0
).float() * height[:, None, :]
ref_height = (dy >= 0).float() * height[:, :, None] + (dy < 0).float() * height[
:, None, :
]
dy0 = y1[:, None, :] - y0[:, :, None] # A begin -> B end
dy1 = y0[:, None, :] - y1[:, :, None] # A end -> B begin
offset = 0.5
Expand All @@ -55,8 +66,7 @@ def compute_pdf_relative_positions(x0, y0, x1, y1, width, height, n_relative_pos
(torch.maximum(dy0, dy1) / ref_height - offset).floor(),
0,
),
)
dy = (dy.abs().ceil() * dy.sign()).long()
).long()

relative_positions = torch.stack([dx, dy], dim=-1)

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ dependencies = [
"pdfminer.six>=20220319",
"pypdfium2~=2.7",
"rich-logger>=0.3.0,<1.0.0",
"safetensors~=0.3.1"
"safetensors~=0.3.1",
"anyascii>=0.3.2",
]

[project.optional-dependencies]
Expand Down

0 comments on commit dd51e50

Please sign in to comment.