diff --git a/edspdf/pipeline.py b/edspdf/pipeline.py index d1652740..8944ade1 100644 --- a/edspdf/pipeline.py +++ b/edspdf/pipeline.py @@ -588,16 +588,21 @@ def collate( return batch def parameters(self): + """Returns an iterator over the Pytorch parameters of the components in the + pipeline""" + return (p for n, p in self.named_parameters()) + + def named_parameters(self): """Returns an iterator over the Pytorch parameters of the components in the pipeline""" seen = set() for name, component in self.pipeline: - if hasattr(component, "parameters"): - for param in component.parameters(): + if hasattr(component, "named_parameters"): + for param_name, param in component.named_parameters(): if param in seen: continue seen.add(param) - yield param + yield f"{name}.{param_name}", param def to(self, device: Optional[torch.device] = None): """Moves the pipeline to a given device""" diff --git a/edspdf/pipes/embeddings/huggingface_embedding.py b/edspdf/pipes/embeddings/huggingface_embedding.py index 39544c4e..b30c7dfe 100644 --- a/edspdf/pipes/embeddings/huggingface_embedding.py +++ b/edspdf/pipes/embeddings/huggingface_embedding.py @@ -256,12 +256,11 @@ def collate(self, batch, device): collated = { "input_ids": as_folded_tensor(batch["input_ids"], **kw, dtype=torch.long), "bbox": as_folded_tensor(batch["bbox"], **kw, dtype=torch.long), - "windows": windows, - "indexer": indexer[line_window_indices], - "line_window_indices": indexer[line_window_indices].as_tensor(), - "line_window_offsets_flat": line_window_offsets_flat, + "windows": windows.to(device), + "indexer": indexer[line_window_indices].to(device), + "line_window_indices": indexer[line_window_indices].as_tensor().to(device), + "line_window_offsets_flat": line_window_offsets_flat.to(device), } - print(windows_count_per_page) if self.use_image: collated["pixel_values"] = ( torch.stack( diff --git a/edspdf/structures.py b/edspdf/structures.py index 46629a64..ba632bee 100644 --- a/edspdf/structures.py +++ b/edspdf/structures.py @@ -201,7 +201,7 @@ class Box(BaseModel): @property def page(self): - return self.doc.pages[self.page_num] + return next(p for p in self.doc.pages if p.page_num == self.page_num) def __lt__(self, other): self_page_num = self.page_num or 0