Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
Improve database models (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
VanessaBotha authored Nov 12, 2023
1 parent 30f3ecf commit e0e1ed4
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 33 deletions.
2 changes: 1 addition & 1 deletion ahcore/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def _batch_end(
tile_overlap = inference_grid.tile_overlap

# TODO: We are really putting strange things in the Queue if we may believe mypy
new_queue: Queue[Any] = Queue() # pylint: disable=unsubscriptable-object
new_queue: Queue[Any] = Queue() # pylint: disable=unsubscriptable-object
parent_conn, child_conn = Pipe()
new_writer = H5FileImageWriter(
output_filename,
Expand Down
60 changes: 32 additions & 28 deletions ahcore/utils/database_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from sqlalchemy.orm import DeclarativeBase, Mapped, relationship


class CategoryEnum(PyEnum):
TRAIN = "fit"
class CategoryEnum(str, PyEnum):
FIT = "fit"
VALIDATE = "validate"
TEST = "test"
PREDICT = "predict"
Expand Down Expand Up @@ -44,7 +44,7 @@ class Patient(Base):
manifest: Mapped["Manifest"] = relationship("Manifest", back_populates="patients")
images: Mapped[List["Image"]] = relationship("Image", back_populates="patient")
labels: Mapped[List["PatientLabels"]] = relationship("PatientLabels", back_populates="patient")
split: Mapped[List["Split"]] = relationship("Split", uselist=False, back_populates="patient")
splits: Mapped[List["Split"]] = relationship("Split", back_populates="patient")


class Image(Base):
Expand All @@ -55,9 +55,9 @@ class Image(Base):
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
filename = Column(String, unique=True)
filename = Column(String, unique=True, nullable=False)
reader = Column(String)
patient_id = Column(Integer, ForeignKey("patient.id"))
patient_id = Column(Integer, ForeignKey("patient.id"), nullable=False)

height = Column(Integer)
width = Column(Integer)
Expand All @@ -66,8 +66,8 @@ class Image(Base):
patient: Mapped["Patient"] = relationship("Patient", back_populates="images")
masks: Mapped[List["Mask"]] = relationship("Mask", back_populates="image")
annotations: Mapped[List["ImageAnnotations"]] = relationship("ImageAnnotations", back_populates="image")
labels: Mapped["ImageLabels"] = relationship("ImageLabels", back_populates="image")
cache: Mapped["ImageCache"] = relationship("ImageCache", uselist=False, back_populates="image")
labels: Mapped[List["ImageLabels"]] = relationship("ImageLabels", back_populates="image")
caches: Mapped[List["ImageCache"]] = relationship("ImageCache", back_populates="image")


class ImageCache(Base):
Expand All @@ -79,14 +79,14 @@ class ImageCache(Base):
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
filename = Column(String, unique=True)
filename = Column(String, unique=True, nullable=False)
reader = Column(String)
num_tiles = Column(Integer)
image_id = Column(Integer, ForeignKey("image.id"))

image: Mapped["Image"] = relationship("Image", back_populates="cache")
image_id = Column(Integer, ForeignKey("image.id"), nullable=False)
description_id = Column(Integer, ForeignKey("cache_description.id"))
description: Mapped["CacheDescription"] = relationship("CacheDescription", back_populates="caches")

image: Mapped["Image"] = relationship("Image", back_populates="caches")
description: Mapped["CacheDescription"] = relationship("CacheDescription", back_populates="cache")


class CacheDescription(Base):
Expand All @@ -108,7 +108,7 @@ class CacheDescription(Base):
mask_threshold = Column(Float)
grid_order = Column(String)

caches: Mapped["ImageCache"] = relationship("ImageCache", back_populates="description")
cache: Mapped["ImageCache"] = relationship("ImageCache", back_populates="description")


class Mask(Base):
Expand All @@ -121,7 +121,7 @@ class Mask(Base):
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
filename = Column(String, unique=True)
reader = Column(String)
image_id = Column(Integer, ForeignKey("image.id"))
image_id = Column(Integer, ForeignKey("image.id"), nullable=False)

image: Mapped["Image"] = relationship("Image", back_populates="masks")

Expand All @@ -136,7 +136,7 @@ class ImageAnnotations(Base):
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
filename = Column(String, unique=True)
reader = Column(String)
image_id = Column(Integer, ForeignKey("image.id"))
image_id = Column(Integer, ForeignKey("image.id"), nullable=False)

image: Mapped["Image"] = relationship("Image", back_populates="annotations")

Expand All @@ -149,11 +149,14 @@ class ImageLabels(Base):
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
label_data = Column(String) # e.g. "cancer" or "benign"
image_id = Column(Integer, ForeignKey("image.id"))
key = Column(String, nullable=False)
value = Column(String, nullable=False)
image_id = Column(Integer, ForeignKey("image.id"), nullable=False)

image: Mapped["Image"] = relationship("Image", back_populates="labels")

__table_args__ = (UniqueConstraint("key", "image_id", name="uq_image_label_key"),)


class PatientLabels(Base):
"""Patient labels table."""
Expand All @@ -163,15 +166,14 @@ class PatientLabels(Base):
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
key = Column(String)
value = Column(String)
patient_id = Column(Integer, ForeignKey("patient.id"))

# Add a unique constraint
__table_args__ = (UniqueConstraint("key", "patient_id", name="uq_patient_label_key"),)
key = Column(String, nullable=False)
value = Column(String, nullable=False)
patient_id = Column(Integer, ForeignKey("patient.id"), nullable=False)

patient: Mapped["Patient"] = relationship("Patient", back_populates="labels")

__table_args__ = (UniqueConstraint("key", "patient_id", name="uq_patient_label_key"),)


class SplitDefinitions(Base):
"""Split definitions table."""
Expand All @@ -181,8 +183,9 @@ class SplitDefinitions(Base):
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
version = Column(String, nullable=False)
version = Column(String, nullable=False, unique=True)
description = Column(String)

splits: Mapped[List["Split"]] = relationship("Split", back_populates="split_definition")


Expand All @@ -196,9 +199,10 @@ class Split(Base):
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
category: Column[CategoryEnum] = Column(Enum(CategoryEnum), nullable=False)

patient_id = Column(Integer, ForeignKey("patient.id"))
patient: Mapped["Patient"] = relationship("Patient", back_populates="split")

patient_id = Column(Integer, ForeignKey("patient.id"), nullable=False)
split_definition_id = Column(Integer, ForeignKey("split_definitions.id"), nullable=False)

patient: Mapped["Patient"] = relationship("Patient", back_populates="splits")
split_definition: Mapped["SplitDefinitions"] = relationship("SplitDefinitions", back_populates="splits")

__table_args__ = (UniqueConstraint("split_definition_id", "patient_id", name="uq_patient_split_key"),)
10 changes: 6 additions & 4 deletions tools/populate_tcga_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def populate_from_annotated_tcga(
split_category = CategoryEnum.PREDICT
else:
split_category = random.choices(
[CategoryEnum.TRAIN, CategoryEnum.VALIDATE, CategoryEnum.TEST],
[CategoryEnum.FIT, CategoryEnum.VALIDATE, CategoryEnum.TEST],
[67, 33, 0],
)[0]

Expand Down Expand Up @@ -118,8 +118,10 @@ def populate_from_annotated_tcga(
image_annotation = ImageAnnotations(filename=str(annotation_path), reader="GEOJSON", image=image)
session.add(image_annotation)

label_data = "cancer" if random.choice([True, False]) else "benign" # Randomly decide if it's cancer or benign
image_label = ImageLabels(label_data=label_data, image=image)
# Randomly decide if it's cancer or benign
image_label = ImageLabels(
key="tumor_type", value="cancer" if random.choice([True, False]) else "benign", image=image
)
session.add(image_label)

session.commit()
Expand All @@ -129,5 +131,5 @@ def populate_from_annotated_tcga(
annotation_folder = Path("tissue_subtypes/v20230228_debug/")
image_folder = Path("/data/groups/aiforoncology/archive/pathology/TCGA/images/")
path_to_mapping = Path("/data/groups/aiforoncology/archive/pathology/TCGA/identifier_mapping.json")
with open_db("manifest.db") as session:
with open_db("sqlite:///manifest.db") as session:
populate_from_annotated_tcga(session, image_folder, annotation_folder, path_to_mapping, predict=True)

0 comments on commit e0e1ed4

Please sign in to comment.