diff --git a/edspdf/pipeline.py b/edspdf/pipeline.py
index 8944ade1..4f66824d 100644
--- a/edspdf/pipeline.py
+++ b/edspdf/pipeline.py
@@ -610,7 +610,6 @@ def to(self, device: Optional[torch.device] = None):
             component.to(device)
         return self
 
-    @contextmanager
     def train(self, mode=True):
         """
         Enables training mode on pytorch modules
@@ -621,12 +620,19 @@ def train(self, mode=True):
             Whether to enable training or not
         """
 
+        class context:
+            def __enter__(self):
+                pass
+
+            def __exit__(ctx_self, type, value, traceback):
+                for name, proc in self.trainable_pipes():
+                    proc.train(was_training[name])
+
         was_training = {name: proc.training for name, proc in self.trainable_pipes()}
         for name, proc in self.trainable_pipes():
             proc.train(mode)
-        yield
-        for name, proc in self.trainable_pipes():
-            proc.train(was_training[name])
+
+        return context()
 
     def score(self, docs: Sequence[PDFDoc], batch_size: int = None) -> Dict[str, Any]:
         """
diff --git a/edspdf/pipes/embeddings/box_layout_preprocessor.py b/edspdf/pipes/embeddings/box_layout_preprocessor.py
index 42363ec4..4e1d98c3 100644
--- a/edspdf/pipes/embeddings/box_layout_preprocessor.py
+++ b/edspdf/pipes/embeddings/box_layout_preprocessor.py
@@ -10,7 +10,6 @@
 BoxLayoutBatch = TypedDict(
     "BoxLayoutBatch",
     {
-        "page": FoldedTensor,
         "xmin": FoldedTensor,
         "ymin": FoldedTensor,
         "xmax": FoldedTensor,
@@ -62,10 +61,9 @@ def __init__(
 
     def preprocess(self, doc: PDFDoc, supervision: bool = False):
         pages = doc.pages
-        box_pages = [[b.page_num for b in page.text_boxes] for page in pages]
-        last_p = max((p for x in box_pages for p in x), default=0)
+        [[b.page_num for b in page.text_boxes] for page in pages]
+        last_p = doc.num_pages - 1
         return {
-            "page": box_pages,
             "xmin": [[b.x0 for b in p.text_boxes] for p in pages],
             "ymin": [[b.y0 for b in p.text_boxes] for p in pages],
             "xmax": [[b.x1 for b in p.text_boxes] for p in pages],
@@ -84,7 +82,6 @@ def collate(self, batch, device: torch.device) -> BoxLayoutBatch:
         }
 
         return {
-            "page": as_folded_tensor(batch["page"], dtype=torch.float, **kw),
             "xmin": as_folded_tensor(batch["xmin"], dtype=torch.float, **kw),
             "ymin": as_folded_tensor(batch["ymin"], dtype=torch.float, **kw),
             "xmax": as_folded_tensor(batch["xmax"], dtype=torch.float, **kw),
diff --git a/edspdf/pipes/embeddings/simple_text_embedding.py b/edspdf/pipes/embeddings/simple_text_embedding.py
index d849947a..71a2c5e9 100644
--- a/edspdf/pipes/embeddings/simple_text_embedding.py
+++ b/edspdf/pipes/embeddings/simple_text_embedding.py
@@ -209,7 +209,8 @@ def preprocess(self, doc: PDFDoc):
                 words = [m.group(0) for m in self.word_regex.finditer(b.text)]
 
                 for word in words:
-                    ascii_str = anyascii(word)
+                    # ascii_str = unidecode.unidecode(word)
+                    ascii_str = anyascii(word).strip()
                     tokens_shape[-1][i].append(
                         self.shape_voc.encode(word_shape(ascii_str))
                     )
@@ -253,7 +254,7 @@ def forward(self, batch: BoxTextEmbeddingInputBatch) -> EmbeddingOutput:
             self.shape_embedding(batch["tokens_shape"].as_tensor())
             + self.prefix_embedding(batch["tokens_prefix"].as_tensor())
             + self.suffix_embedding(batch["tokens_suffix"].as_tensor())
-            + self.norm_embedding(batch["tokens_norm"].as_tensor())
+            # + self.norm_embedding(batch["tokens_norm"].as_tensor())
         )
 
         return {"embeddings": batch["tokens_shape"].with_data(text_embeds)}
diff --git a/edspdf/pipes/embeddings/sub_box_cnn_pooler.py b/edspdf/pipes/embeddings/sub_box_cnn_pooler.py
index 849123f1..a917281e 100644
--- a/edspdf/pipes/embeddings/sub_box_cnn_pooler.py
+++ b/edspdf/pipes/embeddings/sub_box_cnn_pooler.py
@@ -98,10 +98,12 @@ def forward(self, batch: Any) -> EmbeddingOutput:
             dim=2,
         )
         pooled = box_token_embeddings.max(1).values
+        pooled = self.linear(pooled)
+        # print("TEXT EMBEDS", pooled.shape, pooled.sum())
 
         return {
             "embeddings": as_folded_tensor(
-                data=self.linear(pooled),
+                data=pooled,
                 lengths=embeddings.lengths[:-1],  # pooled on the last dim
                 data_dims=["line"],  # fully flattened
                 full_names=["sample", "page", "line"],
diff --git a/edspdf/structures.py b/edspdf/structures.py
index ba632bee..a8ab615a 100644
--- a/edspdf/structures.py
+++ b/edspdf/structures.py
@@ -89,6 +89,7 @@ class PDFDoc(BaseModel):
 
     content: bytes = attrs.field(repr=lambda c: f"{len(c)} bytes")
     id: str = None
+    num_pages: int = 0
     pages: List["Page"] = attrs.field(factory=list)
     error: bool = False
     content_boxes: List[Union["TextBox"]] = attrs.field(factory=list)
diff --git a/edspdf/utils/torch.py b/edspdf/utils/torch.py
index 29ac115d..634c2ec2 100644
--- a/edspdf/utils/torch.py
+++ b/edspdf/utils/torch.py
@@ -34,14 +34,25 @@ def compute_pdf_relative_positions(x0, y0, x1, y1, width, height, n_relative_pos
     torch.LongTensor
         Shape: n_pages * n_boxes * n_boxes * 2
     """
-    dx = x0[:, None, :] - x0[:, :, None]  # B begin -> A begin
-    dx = (dx * n_relative_positions).long()
+    dx0 = x1[:, None, :] - x0[:, :, None]  # B end -> A begin
+    dx1 = x0[:, None, :] - x1[:, :, None]  # B begin -> A end
+    dx = (
+        torch.where(
+            (dx0.sign() > 0) & (dx1.sign() > 0),
+            torch.minimum(dx0, dx1),
+            torch.where(
+                (dx0.sign() < 0) & (dx1.sign() < 0), torch.maximum(dx0, dx1), 0
+            ),
+        )
+        / 2
+        * n_relative_positions
+    ).long()
 
     dy = y0[:, None, :] - y0[:, :, None]
     # If query above (dy > 0) key, use query height
-    ref_height = (dy >= 0).float() * height.float()[:, :, None] + (
-        dy < 0
-    ).float() * height[:, None, :]
+    ref_height = (dy >= 0).float() * height[:, :, None] + (dy < 0).float() * height[
+        :, None, :
+    ]
     dy0 = y1[:, None, :] - y0[:, :, None]  # A begin -> B end
     dy1 = y0[:, None, :] - y1[:, :, None]  # A end -> B begin
     offset = 0.5
@@ -55,8 +66,7 @@
             (torch.maximum(dy0, dy1) / ref_height - offset).floor(),
             0,
         ),
-    )
-    dy = (dy.abs().ceil() * dy.sign()).long()
+    ).long()
 
     relative_positions = torch.stack([dx, dy], dim=-1)
 
diff --git a/pyproject.toml b/pyproject.toml
index 4830223b..09e25a4e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,8 @@ dependencies = [
     "pdfminer.six>=20220319",
     "pypdfium2~=2.7",
     "rich-logger>=0.3.0,<1.0.0",
-    "safetensors~=0.3.1"
+    "safetensors~=0.3.1",
+    "anyascii>=0.3.2",
 ]
 
 [project.optional-dependencies]
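
A note on the `Pipeline.train` change above: the `@contextmanager` generator is replaced by a method that switches the mode eagerly and returns an object whose `__exit__` restores the previous modes, so `pipeline.train(mode)` now works both as a plain call (the switch persists) and as a context manager (the old modes come back on exit). Below is a minimal standalone sketch of that pattern; `TinyPipeline` and its single `torch.nn.Linear` pipe are hypothetical stand-ins, not edspdf classes.

```python
import torch


class TinyPipeline:
    """Hypothetical stand-in for edspdf's Pipeline, just to show the pattern."""

    def __init__(self):
        self.pipes = {"embedding": torch.nn.Linear(4, 4)}

    def trainable_pipes(self):
        return list(self.pipes.items())

    def train(self, mode=True):
        class context:
            def __enter__(ctx_self):
                pass

            def __exit__(ctx_self, exc_type, exc_value, traceback):
                # Restore the modes snapshotted before the switch below
                for name, proc in self.trainable_pipes():
                    proc.train(was_training[name])

        # Snapshot the current modes, then switch every pipe eagerly
        was_training = {name: proc.training for name, proc in self.trainable_pipes()}
        for name, proc in self.trainable_pipes():
            proc.train(mode)
        # If the caller discards this object, __exit__ never runs and the
        # switch simply persists, like a regular setter
        return context()


pipeline = TinyPipeline()

# Plain call: the mode change persists
pipeline.pipes["embedding"].eval()
pipeline.train(True)
assert pipeline.pipes["embedding"].training

# Context-manager call: the previous mode is restored on exit
pipeline.pipes["embedding"].eval()
with pipeline.train(True):
    assert pipeline.pipes["embedding"].training
assert not pipeline.pipes["embedding"].training
```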
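
Similarly, the `compute_pdf_relative_positions` hunk replaces the begin-to-begin horizontal delta with a signed gap between boxes: positive when the key box lies entirely to the right of the query box, negative when entirely to the left, and zero when the two overlap horizontally. A small sketch of just that rule, on made-up normalized coordinates rather than real edspdf tensors:

```python
import torch

# One page, three boxes with normalized [x0, x1] extents:
# box 0: [0.0, 0.3], box 1: [0.5, 0.9], box 2: [0.4, 0.6]
x0 = torch.tensor([[0.0, 0.5, 0.4]])  # page * box
x1 = torch.tensor([[0.3, 0.9, 0.6]])

dx0 = x1[:, None, :] - x0[:, :, None]  # key end -> query begin
dx1 = x0[:, None, :] - x1[:, :, None]  # key begin -> query end
gap = torch.where(
    (dx0.sign() > 0) & (dx1.sign() > 0),  # key entirely right of query
    torch.minimum(dx0, dx1),
    torch.where(
        (dx0.sign() < 0) & (dx1.sign() < 0),  # key entirely left of query
        torch.maximum(dx0, dx1),
        torch.zeros_like(dx0),  # boxes overlap horizontally
    ),
)

print(gap[0])
# gap[0, 0, 1] == 0.2  -> box 1 starts 0.2 to the right of box 0's end
# gap[0, 1, 0] == -0.2 -> box 0 ends 0.2 to the left of box 1's start
# gap[0, 1, 2] == 0.0  -> boxes 1 and 2 overlap horizontally
```

In the patch itself this gap is then scaled by `n_relative_positions / 2` and truncated with `.long()` to yield discrete relative-position buckets.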