Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Improve database models #14

Merged
merged 3 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ahcore/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def _batch_end(
tile_overlap = inference_grid.tile_overlap

# TODO: We are really putting strange things in the Queue if we may believe mypy
new_queue: Queue[Any] = Queue() # pylint: disable=unsubscriptable-object
new_queue: Queue[Any] = Queue() # pylint: disable=unsubscriptable-object
parent_conn, child_conn = Pipe()
new_writer = H5FileImageWriter(
output_filename,
Expand Down
60 changes: 32 additions & 28 deletions ahcore/utils/database_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from sqlalchemy.orm import DeclarativeBase, Mapped, relationship


class CategoryEnum(PyEnum):
TRAIN = "fit"
class CategoryEnum(str, PyEnum):
FIT = "fit"
VALIDATE = "validate"
TEST = "test"
PREDICT = "predict"
Expand Down Expand Up @@ -44,7 +44,7 @@ class Patient(Base):
manifest: Mapped["Manifest"] = relationship("Manifest", back_populates="patients")
images: Mapped[List["Image"]] = relationship("Image", back_populates="patient")
labels: Mapped[List["PatientLabels"]] = relationship("PatientLabels", back_populates="patient")
split: Mapped[List["Split"]] = relationship("Split", uselist=False, back_populates="patient")
splits: Mapped[List["Split"]] = relationship("Split", back_populates="patient")


class Image(Base):
Expand All @@ -55,9 +55,9 @@ class Image(Base):
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
filename = Column(String, unique=True)
filename = Column(String, unique=True, nullable=False)
reader = Column(String)
patient_id = Column(Integer, ForeignKey("patient.id"))
patient_id = Column(Integer, ForeignKey("patient.id"), nullable=False)

height = Column(Integer)
width = Column(Integer)
Expand All @@ -66,8 +66,8 @@ class Image(Base):
patient: Mapped["Patient"] = relationship("Patient", back_populates="images")
masks: Mapped[List["Mask"]] = relationship("Mask", back_populates="image")
annotations: Mapped[List["ImageAnnotations"]] = relationship("ImageAnnotations", back_populates="image")
labels: Mapped["ImageLabels"] = relationship("ImageLabels", back_populates="image")
cache: Mapped["ImageCache"] = relationship("ImageCache", uselist=False, back_populates="image")
labels: Mapped[List["ImageLabels"]] = relationship("ImageLabels", back_populates="image")
caches: Mapped[List["ImageCache"]] = relationship("ImageCache", back_populates="image")


class ImageCache(Base):
Expand All @@ -79,14 +79,14 @@ class ImageCache(Base):
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
filename = Column(String, unique=True)
filename = Column(String, unique=True, nullable=False)
reader = Column(String)
num_tiles = Column(Integer)
image_id = Column(Integer, ForeignKey("image.id"))

image: Mapped["Image"] = relationship("Image", back_populates="cache")
image_id = Column(Integer, ForeignKey("image.id"), nullable=False)
description_id = Column(Integer, ForeignKey("cache_description.id"))
description: Mapped["CacheDescription"] = relationship("CacheDescription", back_populates="caches")

image: Mapped["Image"] = relationship("Image", back_populates="caches")
description: Mapped["CacheDescription"] = relationship("CacheDescription", back_populates="cache")


class CacheDescription(Base):
Expand All @@ -108,7 +108,7 @@ class CacheDescription(Base):
mask_threshold = Column(Float)
grid_order = Column(String)

caches: Mapped["ImageCache"] = relationship("ImageCache", back_populates="description")
cache: Mapped["ImageCache"] = relationship("ImageCache", back_populates="description")


class Mask(Base):
Expand All @@ -121,7 +121,7 @@ class Mask(Base):
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
filename = Column(String, unique=True)
reader = Column(String)
image_id = Column(Integer, ForeignKey("image.id"))
image_id = Column(Integer, ForeignKey("image.id"), nullable=False)

image: Mapped["Image"] = relationship("Image", back_populates="masks")

Expand All @@ -136,7 +136,7 @@ class ImageAnnotations(Base):
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
filename = Column(String, unique=True)
reader = Column(String)
image_id = Column(Integer, ForeignKey("image.id"))
image_id = Column(Integer, ForeignKey("image.id"), nullable=False)

image: Mapped["Image"] = relationship("Image", back_populates="annotations")

Expand All @@ -149,11 +149,14 @@ class ImageLabels(Base):
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
label_data = Column(String) # e.g. "cancer" or "benign"
image_id = Column(Integer, ForeignKey("image.id"))
key = Column(String, nullable=False)
value = Column(String, nullable=False)
image_id = Column(Integer, ForeignKey("image.id"), nullable=False)

image: Mapped["Image"] = relationship("Image", back_populates="labels")

__table_args__ = (UniqueConstraint("key", "image_id", name="uq_image_label_key"),)


class PatientLabels(Base):
"""Patient labels table."""
Expand All @@ -163,15 +166,14 @@ class PatientLabels(Base):
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
key = Column(String)
value = Column(String)
patient_id = Column(Integer, ForeignKey("patient.id"))

# Add a unique constraint
__table_args__ = (UniqueConstraint("key", "patient_id", name="uq_patient_label_key"),)
key = Column(String, nullable=False)
value = Column(String, nullable=False)
patient_id = Column(Integer, ForeignKey("patient.id"), nullable=False)

patient: Mapped["Patient"] = relationship("Patient", back_populates="labels")

__table_args__ = (UniqueConstraint("key", "patient_id", name="uq_patient_label_key"),)


class SplitDefinitions(Base):
"""Split definitions table."""
Expand All @@ -181,8 +183,9 @@ class SplitDefinitions(Base):
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
version = Column(String, nullable=False)
version = Column(String, nullable=False, unique=True)
description = Column(String)

splits: Mapped[List["Split"]] = relationship("Split", back_populates="split_definition")


Expand All @@ -196,9 +199,10 @@ class Split(Base):
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
category: Column[CategoryEnum] = Column(Enum(CategoryEnum), nullable=False)

patient_id = Column(Integer, ForeignKey("patient.id"))
patient: Mapped["Patient"] = relationship("Patient", back_populates="split")

patient_id = Column(Integer, ForeignKey("patient.id"), nullable=False)
split_definition_id = Column(Integer, ForeignKey("split_definitions.id"), nullable=False)

patient: Mapped["Patient"] = relationship("Patient", back_populates="splits")
split_definition: Mapped["SplitDefinitions"] = relationship("SplitDefinitions", back_populates="splits")

__table_args__ = (UniqueConstraint("split_definition_id", "patient_id", name="uq_patient_split_key"),)
10 changes: 6 additions & 4 deletions tools/populate_tcga_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def populate_from_annotated_tcga(
split_category = CategoryEnum.PREDICT
else:
split_category = random.choices(
[CategoryEnum.TRAIN, CategoryEnum.VALIDATE, CategoryEnum.TEST],
[CategoryEnum.FIT, CategoryEnum.VALIDATE, CategoryEnum.TEST],
[67, 33, 0],
)[0]

Expand Down Expand Up @@ -118,8 +118,10 @@ def populate_from_annotated_tcga(
image_annotation = ImageAnnotations(filename=str(annotation_path), reader="GEOJSON", image=image)
session.add(image_annotation)

label_data = "cancer" if random.choice([True, False]) else "benign" # Randomly decide if it's cancer or benign
image_label = ImageLabels(label_data=label_data, image=image)
# Randomly decide if it's cancer or benign
image_label = ImageLabels(
key="tumor_type", value="cancer" if random.choice([True, False]) else "benign", image=image
)
session.add(image_label)

session.commit()
Expand All @@ -129,5 +131,5 @@ def populate_from_annotated_tcga(
annotation_folder = Path("tissue_subtypes/v20230228_debug/")
image_folder = Path("/data/groups/aiforoncology/archive/pathology/TCGA/images/")
path_to_mapping = Path("/data/groups/aiforoncology/archive/pathology/TCGA/identifier_mapping.json")
with open_db("manifest.db") as session:
with open_db("sqlite:///manifest.db") as session:
populate_from_annotated_tcga(session, image_folder, annotation_folder, path_to_mapping, predict=True)