diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 5b5950a811..caae8cf08d 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -333,6 +333,7 @@ def __init__( # There's a SQLAlchemy relationship declared, that takes precedence # over anything else, use that and continue with the next attribute dict_used[rel_name] = rel_info.sa_relationship + setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315 continue ann = cls.__annotations__[rel_name] temp_field = ModelField.infer( diff --git a/tests/test_main.py b/tests/test_main.py index 22c62327da..72465cda33 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,8 +1,9 @@ -from typing import Optional +from typing import List, Optional import pytest from sqlalchemy.exc import IntegrityError -from sqlmodel import Field, Session, SQLModel, create_engine +from sqlalchemy.orm import RelationshipProperty +from sqlmodel import Field, Relationship, Session, SQLModel, create_engine def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): @@ -91,3 +92,37 @@ class Hero(SQLModel, table=True): session.add(hero_2) session.commit() session.refresh(hero_2) + + +def test_sa_relationship_property(clear_sqlmodel): + """Test https://github.com/tiangolo/sqlmodel/issues/315#issuecomment-1272122306""" + + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(unique=True) + heroes: List["Hero"] = Relationship( # noqa: F821 + sa_relationship=RelationshipProperty("Hero", back_populates="team") + ) + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(unique=True) + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional[Team] = Relationship( + sa_relationship=RelationshipProperty("Team", back_populates="heroes") + ) + + team_preventers = Team(name="Preventers") + hero_rusty_man = Hero(name="Rusty-Man", team=team_preventers) + + engine = create_engine("sqlite://", echo=True) + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero_rusty_man) + session.commit() + session.refresh(hero_rusty_man) + # The next statement should not raise an AttributeError + assert hero_rusty_man.team + assert hero_rusty_man.team.name == "Preventers"