diff --git a/ahcore/callbacks.py b/ahcore/callbacks.py index 05e561a..96bb9e6 100644 --- a/ahcore/callbacks.py +++ b/ahcore/callbacks.py @@ -378,7 +378,7 @@ def _batch_end( tile_overlap = inference_grid.tile_overlap # TODO: We are really putting strange things in the Queue if we may believe mypy - new_queue: Queue[Any] = Queue() # pylint: disable=unsubscriptable-object + new_queue: Queue[Any] = Queue() # pylint: disable=unsubscriptable-object parent_conn, child_conn = Pipe() new_writer = H5FileImageWriter( output_filename, diff --git a/ahcore/utils/database_models.py b/ahcore/utils/database_models.py index 0731d3d..741214d 100644 --- a/ahcore/utils/database_models.py +++ b/ahcore/utils/database_models.py @@ -6,8 +6,8 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, relationship -class CategoryEnum(PyEnum): - TRAIN = "fit" +class CategoryEnum(str, PyEnum): + FIT = "fit" VALIDATE = "validate" TEST = "test" PREDICT = "predict" @@ -44,7 +44,7 @@ class Patient(Base): manifest: Mapped["Manifest"] = relationship("Manifest", back_populates="patients") images: Mapped[List["Image"]] = relationship("Image", back_populates="patient") labels: Mapped[List["PatientLabels"]] = relationship("PatientLabels", back_populates="patient") - split: Mapped[List["Split"]] = relationship("Split", uselist=False, back_populates="patient") + splits: Mapped[List["Split"]] = relationship("Split", back_populates="patient") class Image(Base): @@ -55,9 +55,9 @@ class Image(Base): # pylint: disable=E1102 created = Column(DateTime(timezone=True), default=func.now()) last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) - filename = Column(String, unique=True) + filename = Column(String, unique=True, nullable=False) reader = Column(String) - patient_id = Column(Integer, ForeignKey("patient.id")) + patient_id = Column(Integer, ForeignKey("patient.id"), nullable=False) height = Column(Integer) width = Column(Integer) @@ -66,8 +66,8 @@ class Image(Base): patient: Mapped["Patient"] = relationship("Patient", back_populates="images") masks: Mapped[List["Mask"]] = relationship("Mask", back_populates="image") annotations: Mapped[List["ImageAnnotations"]] = relationship("ImageAnnotations", back_populates="image") - labels: Mapped["ImageLabels"] = relationship("ImageLabels", back_populates="image") - cache: Mapped["ImageCache"] = relationship("ImageCache", uselist=False, back_populates="image") + labels: Mapped[List["ImageLabels"]] = relationship("ImageLabels", back_populates="image") + caches: Mapped[List["ImageCache"]] = relationship("ImageCache", back_populates="image") class ImageCache(Base): @@ -79,14 +79,14 @@ class ImageCache(Base): # pylint: disable=E1102 created = Column(DateTime(timezone=True), default=func.now()) last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) - filename = Column(String, unique=True) + filename = Column(String, unique=True, nullable=False) reader = Column(String) num_tiles = Column(Integer) - image_id = Column(Integer, ForeignKey("image.id")) - - image: Mapped["Image"] = relationship("Image", back_populates="cache") + image_id = Column(Integer, ForeignKey("image.id"), nullable=False) description_id = Column(Integer, ForeignKey("cache_description.id")) - description: Mapped["CacheDescription"] = relationship("CacheDescription", back_populates="caches") + + image: Mapped["Image"] = relationship("Image", back_populates="caches") + description: Mapped["CacheDescription"] = relationship("CacheDescription", back_populates="cache") class CacheDescription(Base): @@ -108,7 +108,7 @@ class CacheDescription(Base): mask_threshold = Column(Float) grid_order = Column(String) - caches: Mapped["ImageCache"] = relationship("ImageCache", back_populates="description") + cache: Mapped["ImageCache"] = relationship("ImageCache", back_populates="description") class Mask(Base): @@ -121,7 +121,7 @@ class Mask(Base): last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) filename = Column(String, unique=True) reader = Column(String) - image_id = Column(Integer, ForeignKey("image.id")) + image_id = Column(Integer, ForeignKey("image.id"), nullable=False) image: Mapped["Image"] = relationship("Image", back_populates="masks") @@ -136,7 +136,7 @@ class ImageAnnotations(Base): last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) filename = Column(String, unique=True) reader = Column(String) - image_id = Column(Integer, ForeignKey("image.id")) + image_id = Column(Integer, ForeignKey("image.id"), nullable=False) image: Mapped["Image"] = relationship("Image", back_populates="annotations") @@ -149,11 +149,14 @@ class ImageLabels(Base): # pylint: disable=E1102 created = Column(DateTime(timezone=True), default=func.now()) last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) - label_data = Column(String) # e.g. "cancer" or "benign" - image_id = Column(Integer, ForeignKey("image.id")) + key = Column(String, nullable=False) + value = Column(String, nullable=False) + image_id = Column(Integer, ForeignKey("image.id"), nullable=False) image: Mapped["Image"] = relationship("Image", back_populates="labels") + __table_args__ = (UniqueConstraint("key", "image_id", name="uq_image_label_key"),) + class PatientLabels(Base): """Patient labels table.""" @@ -163,15 +166,14 @@ class PatientLabels(Base): # pylint: disable=E1102 created = Column(DateTime(timezone=True), default=func.now()) last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) - key = Column(String) - value = Column(String) - patient_id = Column(Integer, ForeignKey("patient.id")) - - # Add a unique constraint - __table_args__ = (UniqueConstraint("key", "patient_id", name="uq_patient_label_key"),) + key = Column(String, nullable=False) + value = Column(String, nullable=False) + patient_id = Column(Integer, ForeignKey("patient.id"), nullable=False) patient: Mapped["Patient"] = relationship("Patient", back_populates="labels") + __table_args__ = (UniqueConstraint("key", "patient_id", name="uq_patient_label_key"),) + class SplitDefinitions(Base): """Split definitions table.""" @@ -181,8 +183,9 @@ class SplitDefinitions(Base): # pylint: disable=E1102 created = Column(DateTime(timezone=True), default=func.now()) last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) - version = Column(String, nullable=False) + version = Column(String, nullable=False, unique=True) description = Column(String) + splits: Mapped[List["Split"]] = relationship("Split", back_populates="split_definition") @@ -196,9 +199,10 @@ class Split(Base): created = Column(DateTime(timezone=True), default=func.now()) last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) category: Column[CategoryEnum] = Column(Enum(CategoryEnum), nullable=False) - - patient_id = Column(Integer, ForeignKey("patient.id")) - patient: Mapped["Patient"] = relationship("Patient", back_populates="split") - + patient_id = Column(Integer, ForeignKey("patient.id"), nullable=False) split_definition_id = Column(Integer, ForeignKey("split_definitions.id"), nullable=False) + + patient: Mapped["Patient"] = relationship("Patient", back_populates="splits") split_definition: Mapped["SplitDefinitions"] = relationship("SplitDefinitions", back_populates="splits") + + __table_args__ = (UniqueConstraint("split_definition_id", "patient_id", name="uq_patient_split_key"),) diff --git a/tools/populate_tcga_db.py b/tools/populate_tcga_db.py index 1be57b0..ed95878 100644 --- a/tools/populate_tcga_db.py +++ b/tools/populate_tcga_db.py @@ -66,7 +66,7 @@ def populate_from_annotated_tcga( split_category = CategoryEnum.PREDICT else: split_category = random.choices( - [CategoryEnum.TRAIN, CategoryEnum.VALIDATE, CategoryEnum.TEST], + [CategoryEnum.FIT, CategoryEnum.VALIDATE, CategoryEnum.TEST], [67, 33, 0], )[0] @@ -118,8 +118,10 @@ def populate_from_annotated_tcga( image_annotation = ImageAnnotations(filename=str(annotation_path), reader="GEOJSON", image=image) session.add(image_annotation) - label_data = "cancer" if random.choice([True, False]) else "benign" # Randomly decide if it's cancer or benign - image_label = ImageLabels(label_data=label_data, image=image) + # Randomly decide if it's cancer or benign + image_label = ImageLabels( + key="tumor_type", value="cancer" if random.choice([True, False]) else "benign", image=image + ) session.add(image_label) session.commit() @@ -129,5 +131,5 @@ def populate_from_annotated_tcga( annotation_folder = Path("tissue_subtypes/v20230228_debug/") image_folder = Path("/data/groups/aiforoncology/archive/pathology/TCGA/images/") path_to_mapping = Path("/data/groups/aiforoncology/archive/pathology/TCGA/identifier_mapping.json") - with open_db("manifest.db") as session: + with open_db("sqlite:///manifest.db") as session: populate_from_annotated_tcga(session, image_folder, annotation_folder, path_to_mapping, predict=True)