Skip to content

Commit

Permalink
draft
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Jul 10, 2023
1 parent db01709 commit 3ed402a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
11 changes: 8 additions & 3 deletions edspdf/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,16 +588,21 @@ def collate(
return batch

def parameters(self):
"""Returns an iterator over the PyTorch parameters of the components in the
pipeline.

The result is a generator built on top of `self.named_parameters()`: it yields
the parameter of each `(name, parameter)` pair and discards the name.
"""
return (p for n, p in self.named_parameters())

def named_parameters(self):
"""Returns an iterator over the named PyTorch parameters of the components in
the pipeline.

Each item is a `(name, parameter)` pair. The name is built as
`f"{name}.{param_name}"`, i.e. the name of the pipeline entry followed by the
name reported by the component's own `named_parameters()`. A parameter that is
already in `seen` is skipped, so a parameter reachable from several components
is yielded only the first time it is met.
"""
# NOTE(review): this listing looks like a rendered diff whose +/- markers and
# indentation were lost. The `parameters`-based lines and the bare
# `yield param` appear to be the superseded versions of the
# `named_parameters`-based lines that follow them; confirm against the
# repository before reusing this code.
seen = set()
for name, component in self.pipeline:
if hasattr(component, "parameters"):
for param in component.parameters():
if hasattr(component, "named_parameters"):
for param_name, param in component.named_parameters():
if param in seen:
continue
seen.add(param)
yield param
yield f"{name}.{param_name}", param

def to(self, device: Optional[torch.device] = None):
"""Moves the pipeline to a given device"""
Expand Down
9 changes: 4 additions & 5 deletions edspdf/pipes/embeddings/huggingface_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,11 @@ def collate(self, batch, device):
collated = {
"input_ids": as_folded_tensor(batch["input_ids"], **kw, dtype=torch.long),
"bbox": as_folded_tensor(batch["bbox"], **kw, dtype=torch.long),
"windows": windows,
"indexer": indexer[line_window_indices],
"line_window_indices": indexer[line_window_indices].as_tensor(),
"line_window_offsets_flat": line_window_offsets_flat,
"windows": windows.to(device),
"indexer": indexer[line_window_indices].to(device),
"line_window_indices": indexer[line_window_indices].as_tensor().to(device),
"line_window_offsets_flat": line_window_offsets_flat.to(device),
}
print(windows_count_per_page)
if self.use_image:
collated["pixel_values"] = (
torch.stack(
Expand Down
2 changes: 1 addition & 1 deletion edspdf/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class Box(BaseModel):

@property
def page(self):
return self.doc.pages[self.page_num]
return next(p for p in self.doc.pages if p.page_num == self.page_num)

def __lt__(self, other):
self_page_num = self.page_num or 0
Expand Down

0 comments on commit 3ed402a

Please sign in to comment.