diff --git a/.editorconfig b/.editorconfig index d629bede5ec5..bf9021ff821d 100644 --- a/.editorconfig +++ b/.editorconfig @@ -4,7 +4,7 @@ root = true # 4 space indentation -[*.py] +[*.{py,pyi}] indent_style = space indent_size = 4 max_line_length = 88 diff --git a/.github/ISSUE_TEMPLATE/BUG_REPORT.yml b/.github/ISSUE_TEMPLATE/BUG_REPORT.yml index 1b304198bc8f..abe0f656a28b 100644 --- a/.github/ISSUE_TEMPLATE/BUG_REPORT.yml +++ b/.github/ISSUE_TEMPLATE/BUG_REPORT.yml @@ -74,6 +74,36 @@ body: - Debian packages from packages.matrix.org - pip (from PyPI) - Other (please mention below) + - I don't know + validations: + required: true + - type: input + id: database + attributes: + label: Database + description: | + Are you using SQLite or PostgreSQL? What's the version of your database? + + If PostgreSQL, please also answer the following: + - are you using a single PostgreSQL server + or [separate servers for `main` and `state`](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#databases)? + - have you previously ported from SQLite using the Synapse "portdb" script? + - have you previously restored from a backup? + validations: + required: true + - type: dropdown + id: workers + attributes: + label: Workers + description: | + Are you running a single Synapse process, or are you running + [2 or more workers](https://matrix-org.github.io/synapse/latest/workers.html)? + options: + - Single process + - Multiple workers + - I don't know + validations: + required: true - type: textarea id: platform attributes: @@ -83,17 +113,28 @@ body: e.g. distro, hardware, if it's running in a vm/container, etc. validations: required: true + - type: textarea + id: config + attributes: + label: Configuration + description: | + Do you have any unusual config options turned on? If so, please provide details. + + - Experimental or undocumented features + - [Presence](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#presence) + - [Message retention](https://matrix-org.github.io/synapse/latest/message_retention_policies.html) + - [Synapse modules](https://matrix-org.github.io/synapse/latest/modules/index.html) - type: textarea id: logs attributes: label: Relevant log output description: | Please copy and paste any relevant log output, ideally at INFO or DEBUG log level. - This will be automatically formatted into code, so there is no need for backticks. + This will be automatically formatted into code, so there is no need for backticks (`\``). Please be careful to remove any personal or private data. - **Bug reports are usually very difficult to diagnose without logging.** + **Bug reports are usually impossible to diagnose without logging.** render: shell validations: required: true diff --git a/.github/workflows/latest_deps.yml b/.github/workflows/latest_deps.yml index c6f481cdaace..a7097d5eaef6 100644 --- a/.github/workflows/latest_deps.yml +++ b/.github/workflows/latest_deps.yml @@ -27,7 +27,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -61,7 +61,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -134,7 +134,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: stable - uses: Swatinem/rust-cache@v2 diff --git a/.github/workflows/push_complement_image.yml b/.github/workflows/push_complement_image.yml new file mode 100644 index 000000000000..f26143de6bbf --- /dev/null +++ b/.github/workflows/push_complement_image.yml @@ -0,0 +1,74 @@ +# This task does not run complement tests, see tests.yaml instead. +# This task does not build docker images for synapse for use on docker hub, see docker.yaml instead + +name: Store complement-synapse image in ghcr.io +on: + push: + branches: [ "master" ] + schedule: + - cron: '0 5 * * *' + workflow_dispatch: + inputs: + branch: + required: true + default: 'develop' + type: choice + options: + - develop + - master + +# Only run this action once per pull request/branch; restart if a new commit arrives. +# C.f. https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#concurrency +# and https://docs.github.com/en/actions/reference/context-and-expression-syntax-for-github-actions#github-context +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build: + name: Build and push complement image + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + steps: + - name: Checkout specific branch (debug build) + uses: actions/checkout@v3 + if: github.event_name == 'workflow_dispatch' + with: + ref: ${{ inputs.branch }} + - name: Checkout clean copy of develop (scheduled build) + uses: actions/checkout@v3 + if: github.event_name == 'schedule' + with: + ref: develop + - name: Checkout clean copy of master (on-push) + uses: actions/checkout@v3 + if: github.event_name == 'push' + with: + ref: master + - name: Login to registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Work out labels for complement image + id: meta + uses: docker/metadata-action@v4 + with: + images: ghcr.io/${{ github.repository }}/complement-synapse + tags: | + type=schedule,pattern=nightly,enable=${{ github.event_name == 'schedule'}} + type=raw,value=develop,enable=${{ github.event_name == 'schedule' || inputs.branch == 'develop' }} + type=raw,value=latest,enable=${{ github.event_name == 'push' || inputs.branch == 'master' }} + type=sha,format=long + - name: Run scripts-dev/complement.sh to generate complement-synapse:latest image. + run: scripts-dev/complement.sh --build-only + - name: Tag and push generated image + run: | + for TAG in ${{ join(fromJson(steps.meta.outputs.json).tags, ' ') }}; do + echo "tag and push $TAG" + docker tag complement-synapse $TAG + docker push $TAG + done diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ec5ab79f9c2b..b687eb002d88 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,6 +27,7 @@ jobs: rust: - 'rust/**' - 'Cargo.toml' + - 'Cargo.lock' check-sampleconfig: runs-on: ubuntu-latest @@ -102,7 +103,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: 1.58.1 components: clippy @@ -122,7 +123,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: 1.58.1 components: rustfmt @@ -184,7 +185,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -228,7 +229,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -346,7 +347,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -489,7 +490,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -517,7 +518,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml index 6a047193f67c..bbbe52d69753 100644 --- a/.github/workflows/twisted_trunk.yml +++ b/.github/workflows/twisted_trunk.yml @@ -18,7 +18,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -43,7 +43,7 @@ jobs: - run: sudo apt-get -qq install xmlsec1 - name: Install Rust - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -82,7 +82,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@55c7845fad90d0ae8b2e83715cb900e5e861e8cb + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f with: toolchain: stable - uses: Swatinem/rust-cache@v2 diff --git a/CHANGES.md b/CHANGES.md index d1997f7379be..0238249218ae 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,83 @@ +Synapse 1.73.0 (2022-12-06) +=========================== + +Please note that legacy Prometheus metric names have been removed in this release; see [the upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.73/docs/upgrade.md#legacy-prometheus-metric-names-have-now-been-removed) for more details. + +No significant changes since 1.73.0rc2. + + +Synapse 1.73.0rc2 (2022-12-01) +============================== + +Bugfixes +-------- + +- Fix a regression in Synapse 1.73.0rc1 where Synapse's main process would stop responding to HTTP requests when a user with a large number of devices logs in. ([\#14582](https://github.com/matrix-org/synapse/issues/14582)) + + +Synapse 1.73.0rc1 (2022-11-29) +============================== + +Features +-------- + +- Speed-up `/messages` with `filter_events_for_client` optimizations. ([\#14527](https://github.com/matrix-org/synapse/issues/14527)) +- Improve DB performance by reducing amount of data that gets read in `device_lists_changes_in_room`. ([\#14534](https://github.com/matrix-org/synapse/issues/14534)) +- Adds support for handling avatar in SSO OIDC login. Contributed by @ashfame. ([\#13917](https://github.com/matrix-org/synapse/issues/13917)) +- Move MSC3030 `/timestamp_to_event` endpoints to stable `v1` location (`/_matrix/client/v1/rooms//timestamp_to_event?ts=&dir=`, `/_matrix/federation/v1/timestamp_to_event/?ts=&dir=`). ([\#14471](https://github.com/matrix-org/synapse/issues/14471)) +- Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.5/client-server-api/#aggregations) which return bundled aggregations. ([\#14491](https://github.com/matrix-org/synapse/issues/14491), [\#14508](https://github.com/matrix-org/synapse/issues/14508), [\#14510](https://github.com/matrix-org/synapse/issues/14510)) +- Add unstable support for an Extensible Events room version (`org.matrix.msc1767.10`) via [MSC1767](https://github.com/matrix-org/matrix-spec-proposals/pull/1767), [MSC3931](https://github.com/matrix-org/matrix-spec-proposals/pull/3931), [MSC3932](https://github.com/matrix-org/matrix-spec-proposals/pull/3932), and [MSC3933](https://github.com/matrix-org/matrix-spec-proposals/pull/3933). ([\#14520](https://github.com/matrix-org/synapse/issues/14520), [\#14521](https://github.com/matrix-org/synapse/issues/14521), [\#14524](https://github.com/matrix-org/synapse/issues/14524)) +- Prune user's old devices on login if they have too many. ([\#14038](https://github.com/matrix-org/synapse/issues/14038), [\#14580](https://github.com/matrix-org/synapse/issues/14580)) + + +Bugfixes +-------- + +- Fix a long-standing bug where paginating from the start of a room did not work. Contributed by @gnunicorn. ([\#14149](https://github.com/matrix-org/synapse/issues/14149)) +- Fix a bug introduced in Synapse 1.58.0 where a user with presence state `org.matrix.msc3026.busy` would mistakenly be set to `online` when calling `/sync` or `/events` on a worker process. ([\#14393](https://github.com/matrix-org/synapse/issues/14393)) +- Fix a bug introduced in Synapse 1.70.0 where a receipt's thread ID was not sent over federation. ([\#14466](https://github.com/matrix-org/synapse/issues/14466)) +- Fix a long-standing bug where the [List media admin API](https://matrix-org.github.io/synapse/latest/admin_api/media_admin_api.html#list-all-media-in-a-room) would fail when processing an image with broken thumbnail information. ([\#14537](https://github.com/matrix-org/synapse/issues/14537)) +- Fix a bug introduced in Synapse 1.67.0 where two logging context warnings would be logged on startup. ([\#14574](https://github.com/matrix-org/synapse/issues/14574)) +- In application service transactions that include the experimental `org.matrix.msc3202.device_one_time_key_counts` key, include a duplicate key of `org.matrix.msc3202.device_one_time_keys_count` to match the name proposed by [MSC3202](https://github.com/matrix-org/matrix-spec-proposals/pull/3202). ([\#14565](https://github.com/matrix-org/synapse/issues/14565)) +- Fix a bug introduced in Synapse 0.9 where Synapse would fail to fetch server keys whose IDs contain a forward slash. ([\#14490](https://github.com/matrix-org/synapse/issues/14490)) + + +Improved Documentation +---------------------- + +- Fixed link to 'Synapse administration endpoints'. ([\#14499](https://github.com/matrix-org/synapse/issues/14499)) + + +Deprecations and Removals +------------------------- + +- Remove legacy Prometheus metrics names. They were deprecated in Synapse v1.69.0 and disabled by default in Synapse v1.71.0. ([\#14538](https://github.com/matrix-org/synapse/issues/14538)) + + +Internal Changes +---------------- + +- Improve type hinting throughout Synapse. ([\#14055](https://github.com/matrix-org/synapse/issues/14055), [\#14412](https://github.com/matrix-org/synapse/issues/14412), [\#14529](https://github.com/matrix-org/synapse/issues/14529), [\#14452](https://github.com/matrix-org/synapse/issues/14452)). +- Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). ([\#14376](https://github.com/matrix-org/synapse/issues/14376), [\#14468](https://github.com/matrix-org/synapse/issues/14468)) +- Remove the `worker_main_http_uri` configuration setting. This is now handled via internal replication. ([\#14400](https://github.com/matrix-org/synapse/issues/14400), [\#14476](https://github.com/matrix-org/synapse/issues/14476)) +- Refactor `federation_sender` and `pusher` configuration loading. ([\#14496](https://github.com/matrix-org/synapse/issues/14496)) +([\#14509](https://github.com/matrix-org/synapse/issues/14509), [\#14573](https://github.com/matrix-org/synapse/issues/14573)) +- Faster joins: do not wait for full state when creating events to send. ([\#14403](https://github.com/matrix-org/synapse/issues/14403)) +- Faster joins: filter out non local events when a room doesn't have its full state. ([\#14404](https://github.com/matrix-org/synapse/issues/14404)) +- Faster joins: send events to initial list of servers if we don't have the full state yet. ([\#14408](https://github.com/matrix-org/synapse/issues/14408)) +- Faster joins: use servers list approximation received during `send_join` (potentially updated with received membership events) in `assert_host_in_room`. ([\#14515](https://github.com/matrix-org/synapse/issues/14515)) +- Fix type logic in TCP replication code that prevented correctly ignoring blank commands. ([\#14449](https://github.com/matrix-org/synapse/issues/14449)) +- Remove option to skip locking of tables when performing emulated upserts, to avoid a class of bugs in future. ([\#14469](https://github.com/matrix-org/synapse/issues/14469)) +- `scripts-dev/federation_client`: Fix routing on servers with `.well-known` files. ([\#14479](https://github.com/matrix-org/synapse/issues/14479)) +- Reduce default third party invite rate limit to 216 invites per day. ([\#14487](https://github.com/matrix-org/synapse/issues/14487)) +- Refactor conversion of device list changes in room to outbound pokes to track unconverted rows using a `(stream ID, room ID)` position instead of updating the `converted_to_destinations` flag on every row. ([\#14516](https://github.com/matrix-org/synapse/issues/14516)) +- Add more prompts to the bug report form. ([\#14522](https://github.com/matrix-org/synapse/issues/14522)) +- Extend editorconfig rules on indent and line length to `.pyi` files. ([\#14526](https://github.com/matrix-org/synapse/issues/14526)) +- Run Rust CI when `Cargo.lock` changes. This is particularly useful for dependabot updates. ([\#14571](https://github.com/matrix-org/synapse/issues/14571)) +- Fix a possible variable shadow in `create_new_client_event`. ([\#14575](https://github.com/matrix-org/synapse/issues/14575)) +- Bump various dependencies in the `poetry.lock` file and in CI scripts. ([\#14557](https://github.com/matrix-org/synapse/issues/14557), [\#14559](https://github.com/matrix-org/synapse/issues/14559), [\#14560](https://github.com/matrix-org/synapse/issues/14560), [\#14500](https://github.com/matrix-org/synapse/issues/14500), [\#14501](https://github.com/matrix-org/synapse/issues/14501), [\#14502](https://github.com/matrix-org/synapse/issues/14502), [\#14503](https://github.com/matrix-org/synapse/issues/14503), [\#14504](https://github.com/matrix-org/synapse/issues/14504), [\#14505](https://github.com/matrix-org/synapse/issues/14505)). + + Synapse 1.72.0 (2022-11-22) =========================== diff --git a/Cargo.lock b/Cargo.lock index 8a8099bc6d98..59d2aec21565 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -323,18 +323,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.147" +version = "1.0.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965" +checksum = "e53f64bb4ba0191d6d0676e1b141ca55047d83b74f5607e6d8eb88126c52c2dc" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.147" +version = "1.0.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852" +checksum = "a55492425aa53521babf6137309e7d34c20bbfbbfcfe2c7f3a047fd1f6b92c0c" dependencies = [ "proc-macro2", "quote", @@ -343,9 +343,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.87" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45" +checksum = "020ff22c755c2ed3f8cf162dbb41a7268d934702f3ed3631656ea597e08fc3db" dependencies = [ "itoa", "ryu", @@ -366,9 +366,9 @@ checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" [[package]] name = "syn" -version = "1.0.102" +version = "1.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1" +checksum = "4ae548ec36cf198c0ef7710d3c230987c2d6d7bd98ad6edc0274462724c585ce" dependencies = [ "proc-macro2", "quote", diff --git a/debian/changelog b/debian/changelog index 1f1b4daa3151..163b7210bfde 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,21 @@ +matrix-synapse-py3 (1.73.0) stable; urgency=medium + + * New Synapse release 1.73.0. + + -- Synapse Packaging team Tue, 06 Dec 2022 11:48:56 +0000 + +matrix-synapse-py3 (1.73.0~rc2) stable; urgency=medium + + * New Synapse release 1.73.0rc2. + + -- Synapse Packaging team Thu, 01 Dec 2022 10:02:19 +0000 + +matrix-synapse-py3 (1.73.0~rc1) stable; urgency=medium + + * New Synapse release 1.73.0rc1. + + -- Synapse Packaging team Tue, 29 Nov 2022 12:28:13 +0000 + matrix-synapse-py3 (1.72.0) stable; urgency=medium * New Synapse release 1.72.0. diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 883a87159c16..ca640c343be7 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -100,8 +100,6 @@ experimental_features: # client-side support for partial state in /send_join responses faster_joins: true {% endif %} - # Enable jump to date endpoint - msc3030_enabled: true # Filtering /messages by relation type. msc3874_enabled: true diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 62b1bab297be..58c62f2231f3 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -140,6 +140,7 @@ "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event", "^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms", "^/_matrix/client/(api/v1|r0|v3|unstable/.*)/rooms/.*/aliases", + "^/_matrix/client/v1/rooms/.*/timestamp_to_event$", "^/_matrix/client/(api/v1|r0|v3|unstable)/search", ], "shared_extra_conf": {}, @@ -163,6 +164,7 @@ "^/_matrix/federation/(v1|v2)/invite/", "^/_matrix/federation/(v1|v2)/query_auth/", "^/_matrix/federation/(v1|v2)/event_auth/", + "^/_matrix/federation/v1/timestamp_to_event/", "^/_matrix/federation/(v1|v2)/exchange_third_party_invite/", "^/_matrix/federation/(v1|v2)/user/devices/", "^/_matrix/federation/(v1|v2)/get_groups_publicised$", @@ -213,10 +215,7 @@ "listener_resources": ["client", "replication"], "endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload"], "shared_extra_conf": {}, - "worker_extra_conf": ( - "worker_main_http_uri: http://127.0.0.1:%d" - % (MAIN_PROCESS_HTTP_LISTENER_PORT,) - ), + "worker_extra_conf": "", }, "account_data": { "app": "synapse.app.generic_worker", diff --git a/docs/upgrade.md b/docs/upgrade.md index 2aa353e4962d..4fe9e4f02e9c 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,28 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.73.0 + +## Legacy Prometheus metric names have now been removed + +Synapse v1.69.0 included the deprecation of legacy Prometheus metric names +and offered an option to disable them. +Synapse v1.71.0 disabled legacy Prometheus metric names by default. + +This version, v1.73.0, removes those legacy Prometheus metric names entirely. +This also means that the `enable_legacy_metrics` configuration option has been +removed; it will no longer be possible to re-enable the legacy metric names. + +If you use metrics and have not yet updated your Grafana dashboard(s), +Prometheus console(s) or alerting rule(s), please consider doing so when upgrading +to this version. +Note that the included Grafana dashboard was updated in v1.72.0 to correct some +metric names which were missed when legacy metrics were disabled by default. + +See [v1.69.0: Deprecation of legacy Prometheus metric names](#deprecation-of-legacy-prometheus-metric-names) +for more context. + + # Upgrading to v1.72.0 ## Dropping support for PostgreSQL 10 diff --git a/docs/usage/administration/admin_api/README.md b/docs/usage/administration/admin_api/README.md index f11e0b19a63a..c00de2dd447d 100644 --- a/docs/usage/administration/admin_api/README.md +++ b/docs/usage/administration/admin_api/README.md @@ -19,7 +19,7 @@ already on your `$PATH` depending on how Synapse was installed. Finding your user's `access_token` is client-dependent, but will usually be shown in the client's settings. ## Making an Admin API request -For security reasons, we [recommend](reverse_proxy.md#synapse-administration-endpoints) +For security reasons, we [recommend](../../../reverse_proxy.md#synapse-administration-endpoints) that the Admin API (`/_synapse/admin/...`) should be hidden from public view using a reverse proxy. This means you should typically query the Admin API from a terminal on the machine which runs Synapse. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index f5937dd902ff..749af12aac0b 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2437,31 +2437,6 @@ Example configuration: enable_metrics: true ``` --- -### `enable_legacy_metrics` - -Set to `true` to publish both legacy and non-legacy Prometheus metric names, -or to `false` to only publish non-legacy Prometheus metric names. -Defaults to `false`. Has no effect if `enable_metrics` is `false`. -**In Synapse v1.67.0 up to and including Synapse v1.70.1, this defaulted to `true`.** - -Legacy metric names include: -- metrics containing colons in the name, such as `synapse_util_caches_response_cache:hits`, because colons are supposed to be reserved for user-defined recording rules; -- counters that don't end with the `_total` suffix, such as `synapse_federation_client_sent_edus`, therefore not adhering to the OpenMetrics standard. - -These legacy metric names are unconventional and not compliant with OpenMetrics standards. -They are included for backwards compatibility. - -Example configuration: -```yaml -enable_legacy_metrics: false -``` - -See https://github.com/matrix-org/synapse/issues/11106 for context. - -*Since v1.67.0.* - -**Will be removed in v1.73.0.** ---- ### `sentry` Use this option to enable sentry integration. Provide the DSN assigned to you by sentry @@ -2993,10 +2968,17 @@ Options for each entry include: For the default provider, the following settings are available: - * subject_claim: name of the claim containing a unique identifier + * `subject_claim`: name of the claim containing a unique identifier for the user. Defaults to 'sub', which OpenID Connect compliant providers should provide. + * `picture_claim`: name of the claim containing an url for the user's profile picture. + Defaults to 'picture', which OpenID Connect compliant providers should provide + and has to refer to a direct image file such as PNG, JPEG, or GIF image file. + + Currently only supported in monolithic (single-process) server configurations + where the media repository runs within the Synapse process. + * `localpart_template`: Jinja2 template for the localpart of the MXID. If this is not set, the user will be prompted to choose their own username (see the documentation for the `sso_auth_account_details.html` diff --git a/docs/workers.md b/docs/workers.md index 7ee8801161cf..2b65acb5edac 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -135,8 +135,8 @@ In the config file for each worker, you must specify: [`worker_replication_http_port`](usage/configuration/config_documentation.md#worker_replication_http_port)). * If handling HTTP requests, a [`worker_listeners`](usage/configuration/config_documentation.md#worker_listeners) option with an `http` listener. - * If handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for - the main process (`worker_main_http_uri`). + * **Synapse 1.72 and older:** if handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for + the main process (`worker_main_http_uri`). This config option is no longer required and is ignored when running Synapse 1.73 and newer. For example: @@ -191,6 +191,7 @@ information. ^/_matrix/federation/(v1|v2)/send_leave/ ^/_matrix/federation/(v1|v2)/invite/ ^/_matrix/federation/v1/event_auth/ + ^/_matrix/federation/v1/timestamp_to_event/ ^/_matrix/federation/v1/exchange_third_party_invite/ ^/_matrix/federation/v1/user/devices/ ^/_matrix/key/v2/query @@ -218,10 +219,10 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ + ^/_matrix/client/v1/rooms/.*/timestamp_to_event$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ # Encryption requests - # Note that ^/_matrix/client/(r0|v3|unstable)/keys/upload/ requires `worker_main_http_uri` ^/_matrix/client/(r0|v3|unstable)/keys/query$ ^/_matrix/client/(r0|v3|unstable)/keys/changes$ ^/_matrix/client/(r0|v3|unstable)/keys/claim$ @@ -376,7 +377,7 @@ responsible for - persisting them to the DB, and finally - updating the events stream. -Because load is sharded in this way, you *must* restart all worker instances when +Because load is sharded in this way, you *must* restart all worker instances when adding or removing event persisters. An `event_persister` should not be mistaken for an `event_creator`. diff --git a/mypy.ini b/mypy.ini index 8f1141a23905..0b6e7df26796 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,6 +11,7 @@ warn_unused_ignores = True local_partial_types = True no_implicit_optional = True disallow_untyped_defs = True +strict_equality = True files = docker/, @@ -58,11 +59,6 @@ exclude = (?x) |tests/server_notices/test_resource_limits_server_notices.py |tests/test_state.py |tests/test_terms_auth.py - |tests/util/caches/test_cached_call.py - |tests/util/caches/test_deferred_cache.py - |tests/util/caches/test_descriptors.py - |tests/util/caches/test_response_cache.py - |tests/util/caches/test_ttlcache.py |tests/util/test_async_helpers.py |tests/util/test_batching_queue.py |tests/util/test_dict_cache.py @@ -117,9 +113,15 @@ disallow_untyped_defs = True [mypy-tests.state.test_profile] disallow_untyped_defs = True +[mypy-tests.storage.test_id_generators] +disallow_untyped_defs = True + [mypy-tests.storage.test_profile] disallow_untyped_defs = True +[mypy-tests.handlers.test_sso] +disallow_untyped_defs = True + [mypy-tests.storage.test_user_directory] disallow_untyped_defs = True @@ -129,9 +131,14 @@ disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] disallow_untyped_defs = True -[mypy-tests.utils] +[mypy-tests.util.caches.*] disallow_untyped_defs = True +[mypy-tests.util.caches.test_descriptors] +disallow_untyped_defs = False + +[mypy-tests.utils] +disallow_untyped_defs = True ;; Dependencies without annotations ;; Before ignoring a module, check to see if type stubs are available. diff --git a/poetry.lock b/poetry.lock index 904ae1490aff..6772c3a0e5a5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -671,7 +671,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" [[package]] name = "phonenumbers" -version = "8.12.56" +version = "8.13.0" description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers." category = "main" optional = false @@ -822,15 +822,15 @@ python-versions = ">=3.6" [[package]] name = "pygithub" -version = "1.56" +version = "1.57" description = "Use the full Github API v3" category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] deprecated = "*" -pyjwt = ">=2.0" +pyjwt = ">=2.4.0" pynacl = ">=1.4.0" requests = ">=2.14.0" @@ -1084,7 +1084,7 @@ doc = ["Sphinx", "sphinx-rtd-theme"] [[package]] name = "sentry-sdk" -version = "1.10.1" +version = "1.11.0" description = "Python client for Sentry (https://sentry.io)" category = "main" optional = true @@ -1105,7 +1105,8 @@ falcon = ["falcon (>=1.4)"] fastapi = ["fastapi (>=0.79.0)"] flask = ["blinker (>=1.1)", "flask (>=0.11)"] httpx = ["httpx (>=0.16.0)"] -pure_eval = ["asttokens", "executing", "pure-eval"] +pure-eval = ["asttokens", "executing", "pure-eval"] +pymongo = ["pymongo (>=3.1)"] pyspark = ["pyspark (>=2.4.4)"] quart = ["blinker (>=1.1)", "quart (>=0.16.1)"] rq = ["rq (>=0.6)"] @@ -1264,11 +1265,11 @@ python-versions = ">= 3.5" [[package]] name = "towncrier" -version = "21.9.0" +version = "22.8.0" description = "Building newsfiles for your project." category = "dev" optional = false -python-versions = "*" +python-versions = ">=3.7" [package.dependencies] click = "*" @@ -1276,7 +1277,7 @@ click-default-group = "*" incremental = "*" jinja2 = "*" setuptools = "*" -tomli = {version = "*", markers = "python_version >= \"3.6\""} +tomli = "*" [package.extras] dev = ["packaging"] @@ -1447,7 +1448,7 @@ python-versions = "*" [[package]] name = "types-pillow" -version = "9.2.2.1" +version = "9.3.0.1" description = "Typing stubs for Pillow" category = "dev" optional = false @@ -2327,8 +2328,8 @@ pathspec = [ {file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"}, ] phonenumbers = [ - {file = "phonenumbers-8.12.56-py2.py3-none-any.whl", hash = "sha256:80a7422cf0999a6f9b7a2e6cfbdbbfcc56ab5b75414dc3b805bbec91276b64a3"}, - {file = "phonenumbers-8.12.56.tar.gz", hash = "sha256:82a4f226c930d02dcdf6d4b29e4cfd8678991fe65c2efd5fdd143557186f0868"}, + {file = "phonenumbers-8.13.0-py2.py3-none-any.whl", hash = "sha256:dbaea9e4005a976bcf18fbe2bb87cb9cd0a3f119136f04188ac412d7741cebf0"}, + {file = "phonenumbers-8.13.0.tar.gz", hash = "sha256:93745d7afd38e246660bb601b07deac54eeb76c8e5e43f5e83333b0383a0a1e4"}, ] pillow = [ {file = "Pillow-9.3.0-1-cp37-cp37m-win32.whl", hash = "sha256:e6ea6b856a74d560d9326c0f5895ef8050126acfdc7ca08ad703eb0081e82b74"}, @@ -2511,8 +2512,8 @@ pyflakes = [ {file = "pyflakes-2.5.0.tar.gz", hash = "sha256:491feb020dca48ccc562a8c0cbe8df07ee13078df59813b83959cbdada312ea3"}, ] pygithub = [ - {file = "PyGithub-1.56-py3-none-any.whl", hash = "sha256:d15f13d82165306da8a68aefc0f848a6f6432d5febbff13b60a94758ce3ef8b5"}, - {file = "PyGithub-1.56.tar.gz", hash = "sha256:80c6d85cf0f9418ffeb840fd105840af694c4f17e102970badbaf678251f2a01"}, + {file = "PyGithub-1.57-py3-none-any.whl", hash = "sha256:5822febeac2391f1306c55a99af2bc8f86c8bf82ded000030cd02c18f31b731f"}, + {file = "PyGithub-1.57.tar.gz", hash = "sha256:c273f252b278fb81f1769505cc6921bdb6791e1cebd6ac850cc97dad13c31ff3"}, ] pygments = [ {file = "Pygments-2.11.2-py3-none-any.whl", hash = "sha256:44238f1b60a76d78fc8ca0528ee429702aae011c265fe6a8dd8b63049ae41c65"}, @@ -2660,8 +2661,8 @@ semantic-version = [ {file = "semantic_version-2.10.0.tar.gz", hash = "sha256:bdabb6d336998cbb378d4b9db3a4b56a1e3235701dc05ea2690d9a997ed5041c"}, ] sentry-sdk = [ - {file = "sentry-sdk-1.10.1.tar.gz", hash = "sha256:105faf7bd7b7fa25653404619ee261527266b14103fe1389e0ce077bd23a9691"}, - {file = "sentry_sdk-1.10.1-py2.py3-none-any.whl", hash = "sha256:06c0fa9ccfdc80d7e3b5d2021978d6eb9351fa49db9b5847cf4d1f2a473414ad"}, + {file = "sentry-sdk-1.11.0.tar.gz", hash = "sha256:e7b78a1ddf97a5f715a50ab8c3f7a93f78b114c67307785ee828ef67a5d6f117"}, + {file = "sentry_sdk-1.11.0-py2.py3-none-any.whl", hash = "sha256:f467e6c7fac23d4d42bc83eb049c400f756cd2d65ab44f0cc1165d0c7c3d40bc"}, ] service-identity = [ {file = "service-identity-21.1.0.tar.gz", hash = "sha256:6e6c6086ca271dc11b033d17c3a8bea9f24ebff920c587da090afc9519419d34"}, @@ -2812,8 +2813,8 @@ tornado = [ {file = "tornado-6.1.tar.gz", hash = "sha256:33c6e81d7bd55b468d2e793517c909b139960b6c790a60b7991b9b6b76fb9791"}, ] towncrier = [ - {file = "towncrier-21.9.0-py2.py3-none-any.whl", hash = "sha256:fc5a88a2a54988e3a8ed2b60d553599da8330f65722cc607c839614ed87e0f92"}, - {file = "towncrier-21.9.0.tar.gz", hash = "sha256:9cb6f45c16e1a1eec9d0e7651165e7be60cd0ab81d13a5c96ca97a498ae87f48"}, + {file = "towncrier-22.8.0-py2.py3-none-any.whl", hash = "sha256:3b780c3d966e1b26414830aec3d15000654b31e64e024f3e5fd128b4c6eb8f47"}, + {file = "towncrier-22.8.0.tar.gz", hash = "sha256:7d3839b033859b45fb55df82b74cfd702431933c0cc9f287a5a7ea3e05d042cb"}, ] treq = [ {file = "treq-22.2.0-py3-none-any.whl", hash = "sha256:27d95b07c5c14be3e7b280416139b036087617ad5595be913b1f9b3ce981b9b2"}, @@ -2900,8 +2901,8 @@ types-opentracing = [ {file = "types_opentracing-2.4.10-py3-none-any.whl", hash = "sha256:66d9cfbbdc4a6f8ca8189a15ad26f0fe41cee84c07057759c5d194e2505b84c2"}, ] types-pillow = [ - {file = "types-Pillow-9.2.2.1.tar.gz", hash = "sha256:85c139e06e1c46ec5f9c634d5c54a156b0958d5d0e8be024ed353db0c804b426"}, - {file = "types_Pillow-9.2.2.1-py3-none-any.whl", hash = "sha256:3a6a871cade8428433a21ef459bb0a65532b87d05f9e836a0664431ce445bdcf"}, + {file = "types-Pillow-9.3.0.1.tar.gz", hash = "sha256:f3b7cada3fa496c78d75253c6b1f07a843d625f42e5639b320a72acaff6f7cfb"}, + {file = "types_Pillow-9.3.0.1-py3-none-any.whl", hash = "sha256:79837755fe9659f29efd1016e9903ac4a500e0c73260483f07296bd6ca47668b"}, ] types-psycopg2 = [ {file = "types-psycopg2-2.9.21.1.tar.gz", hash = "sha256:f5532cf15afdc6b5ebb1e59b7d896617217321f488fd1fbd74e7efb94decfab6"}, diff --git a/pyproject.toml b/pyproject.toml index 1882238c724c..f55acd99ffc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ manifest-path = "rust/Cargo.toml" [tool.poetry] name = "matrix-synapse" -version = "1.72.0" +version = "1.73.0" description = "Homeserver for the Matrix decentralised comms protocol" authors = ["Matrix.org Team and Contributors "] license = "Apache-2.0" diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 5559ffd14c64..9cbb2828990f 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -365,6 +365,156 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.encrypted_room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.encrypted")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.message.room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.message")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.file.room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.file")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.image.room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.image")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.video.room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.video")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.audio.room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.audio")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/underride/.m.rule.message"), priority_class: 1, @@ -393,6 +543,126 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.encrypted"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.encrypted")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.message"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.message")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.file"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.file")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.image"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.image")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.video"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.video")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.audio"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.audio")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/underride/.im.vector.jitsi"), priority_class: 1, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 0185b8f9b19d..219b03ab1b65 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::borrow::Cow; use std::collections::BTreeMap; +use crate::push::{PushRule, PushRules}; use anyhow::{Context, Error}; use lazy_static::lazy_static; use log::warn; @@ -29,6 +31,33 @@ use super::{ lazy_static! { /// Used to parse the `is` clause in the room member count condition. static ref INEQUALITY_EXPR: Regex = Regex::new(r"^([=<>]*)([0-9]+)$").expect("valid regex"); + + /// Used to determine which MSC3931 room version feature flags are actually known to + /// the push evaluator. + static ref KNOWN_RVER_FLAGS: Vec = vec![ + RoomVersionFeatures::ExtensibleEvents.as_str().to_string(), + ]; + + /// The "safe" rule IDs which are not affected by MSC3932's behaviour (room versions which + /// declare Extensible Events support ultimately *disable* push rules which do not declare + /// *any* MSC3931 room_version_supports condition). + static ref SAFE_EXTENSIBLE_EVENTS_RULE_IDS: Vec = vec![ + "global/override/.m.rule.master".to_string(), + "global/override/.m.rule.roomnotif".to_string(), + "global/content/.m.rule.contains_user_name".to_string(), + ]; +} + +enum RoomVersionFeatures { + ExtensibleEvents, +} + +impl RoomVersionFeatures { + fn as_str(&self) -> &'static str { + match self { + RoomVersionFeatures::ExtensibleEvents => "org.matrix.msc3932.extensible_events", + } + } } /// Allows running a set of push rules against a particular event. @@ -57,6 +86,13 @@ pub struct PushRuleEvaluator { /// If msc3664, push rules for related events, is enabled. related_event_match_enabled: bool, + + /// If MSC3931 is applicable, the feature flags for the room version. + room_version_feature_flags: Vec, + + /// If MSC3931 (room version feature flags) is enabled. Usually controlled by the same + /// flag as MSC1767 (extensible events core). + msc3931_enabled: bool, } #[pymethods] @@ -70,6 +106,8 @@ impl PushRuleEvaluator { notification_power_levels: BTreeMap, related_events_flattened: BTreeMap>, related_event_match_enabled: bool, + room_version_feature_flags: Vec, + msc3931_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -84,6 +122,8 @@ impl PushRuleEvaluator { sender_power_level, related_events_flattened, related_event_match_enabled, + room_version_feature_flags, + msc3931_enabled, }) } @@ -106,7 +146,22 @@ impl PushRuleEvaluator { continue; } + let rule_id = &push_rule.rule_id().to_string(); + let extev_flag = &RoomVersionFeatures::ExtensibleEvents.as_str().to_string(); + let supports_extensible_events = self.room_version_feature_flags.contains(extev_flag); + let safe_from_rver_condition = SAFE_EXTENSIBLE_EVENTS_RULE_IDS.contains(rule_id); + let mut has_rver_condition = false; + for condition in push_rule.conditions.iter() { + has_rver_condition = has_rver_condition + || match condition { + Condition::Known(known) => match known { + // per MSC3932, we just need *any* room version condition to match + KnownCondition::RoomVersionSupports { feature: _ } => true, + _ => false, + }, + _ => false, + }; match self.match_condition(condition, user_id, display_name) { Ok(true) => {} Ok(false) => continue 'outer, @@ -117,6 +172,13 @@ impl PushRuleEvaluator { } } + // MSC3932: Disable push rules in extensible event-supporting room versions if they + // don't describe *any* MSC3931 room version condition, unless the rule is on the + // safe list. + if !has_rver_condition && !safe_from_rver_condition && supports_extensible_events { + continue; + } + let actions = push_rule .actions .iter() @@ -207,6 +269,15 @@ impl PushRuleEvaluator { false } } + KnownCondition::RoomVersionSupports { feature } => { + if !self.msc3931_enabled { + false + } else { + let flag = feature.to_string(); + KNOWN_RVER_FLAGS.contains(&flag) + && self.room_version_feature_flags.contains(&flag) + } + } }; Ok(result) @@ -365,9 +436,59 @@ fn push_rule_evaluator() { BTreeMap::new(), BTreeMap::new(), true, + vec![], + true, ) .unwrap(); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); } + +#[test] +fn test_requires_room_version_supports_condition() { + let mut flattened_keys = BTreeMap::new(); + flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); + let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + Some(0), + BTreeMap::new(), + BTreeMap::new(), + false, + flags, + true, + ) + .unwrap(); + + // first test: are the master and contains_user_name rules excluded from the "requires room + // version condition" check? + let mut result = evaluator.run( + &FilteredPushRules::default(), + Some("@bob:example.org"), + None, + ); + assert_eq!(result.len(), 3); + + // second test: if an appropriate push rule is in play, does it get handled? + let custom_rule = PushRule { + rule_id: Cow::from("global/underride/.org.example.extensible"), + priority_class: 1, // underride + conditions: Cow::from(vec![Condition::Known( + KnownCondition::RoomVersionSupports { + feature: Cow::from(RoomVersionFeatures::ExtensibleEvents.as_str().to_string()), + }, + )]), + actions: Cow::from(vec![Action::Notify]), + default: false, + default_enabled: true, + }; + let rules = PushRules::new(vec![custom_rule]); + result = evaluator.run( + &FilteredPushRules::py_new(rules, BTreeMap::new(), true, true), + None, + None, + ); + assert_eq!(result.len(), 1); +} diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 4783645248e2..3dd32bd60709 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -279,6 +279,10 @@ pub enum KnownCondition { SenderNotificationPermission { key: Cow<'static, str>, }, + #[serde(rename = "org.matrix.msc3931.room_version_supports")] + RoomVersionSupports { + feature: Cow<'static, str>, + }, } impl IntoPy for Condition { @@ -410,6 +414,7 @@ pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, msc3664_enabled: bool, + msc1767_enabled: bool, } #[pymethods] @@ -419,11 +424,13 @@ impl FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, msc3664_enabled: bool, + msc1767_enabled: bool, ) -> Self { Self { push_rules, enabled_map, msc3664_enabled, + msc1767_enabled, } } @@ -448,6 +455,10 @@ impl FilteredPushRules { return false; } + if !self.msc1767_enabled && rule.rule_id.contains("org.matrix.msc1767") { + return false; + } + true }) .map(|r| { @@ -493,6 +504,18 @@ fn test_deserialize_unstable_msc3664_condition() { )); } +#[test] +fn test_deserialize_unstable_msc3931_condition() { + let json = + r#"{"kind":"org.matrix.msc3931.room_version_supports","feature":"org.example.feature"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }) + )); +} + #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 803c6ce92d14..7744b47097b8 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -162,9 +162,9 @@ else # We only test faster room joins on monoliths, because they are purposefully # being developed without worker support to start with. # - # The tests for importing historical messages (MSC2716) and jump to date (MSC3030) - # also only pass with monoliths, currently. - test_tags="$test_tags,faster_joins,msc2716,msc3030" + # The tests for importing historical messages (MSC2716) also only pass with monoliths, + # currently. + test_tags="$test_tags,faster_joins,msc2716" fi diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 763dd02c477e..b1d5e2e61667 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -46,11 +46,12 @@ import signedjson.types import srvlookup import yaml +from requests import PreparedRequest, Response from requests.adapters import HTTPAdapter from urllib3 import HTTPConnectionPool # uncomment the following to enable debug logging of http requests -# from httplib import HTTPConnection +# from http.client import HTTPConnection # HTTPConnection.debuglevel = 1 @@ -103,6 +104,7 @@ def request( destination: str, path: str, content: Optional[str], + verify_tls: bool, ) -> requests.Response: if method is None: if content is None: @@ -141,7 +143,6 @@ def request( s.mount("matrix://", MatrixConnectionAdapter()) headers: Dict[str, str] = { - "Host": destination, "Authorization": authorization_headers[0], } @@ -152,7 +153,7 @@ def request( method=method, url=dest, headers=headers, - verify=False, + verify=verify_tls, data=content, stream=True, ) @@ -202,6 +203,12 @@ def main() -> None: parser.add_argument("--body", help="Data to send as the body of the HTTP request") + parser.add_argument( + "--insecure", + action="store_true", + help="Disable TLS certificate verification", + ) + parser.add_argument( "path", help="request path, including the '/_matrix/federation/...' prefix." ) @@ -227,6 +234,7 @@ def main() -> None: args.destination, args.path, content=args.body, + verify_tls=not args.insecure, ) sys.stderr.write("Status Code: %d\n" % (result.status_code,)) @@ -254,36 +262,93 @@ def read_args_from_config(args: argparse.Namespace) -> None: class MatrixConnectionAdapter(HTTPAdapter): + def send( + self, + request: PreparedRequest, + *args: Any, + **kwargs: Any, + ) -> Response: + # overrides the send() method in the base class. + + # We need to look for .well-known redirects before passing the request up to + # HTTPAdapter.send(). + assert isinstance(request.url, str) + parsed = urlparse.urlsplit(request.url) + server_name = parsed.netloc + well_known = self._get_well_known(parsed.netloc) + + if well_known: + server_name = well_known + + # replace the scheme in the uri with https, so that cert verification is done + # also replace the hostname if we got a .well-known result + request.url = urlparse.urlunsplit( + ("https", server_name, parsed.path, parsed.query, parsed.fragment) + ) + + # at this point we also add the host header (otherwise urllib will add one + # based on the `host` from the connection returned by `get_connection`, + # which will be wrong if there is an SRV record). + request.headers["Host"] = server_name + + return super().send(request, *args, **kwargs) + + def get_connection( + self, url: str, proxies: Optional[Dict[str, str]] = None + ) -> HTTPConnectionPool: + # overrides the get_connection() method in the base class + parsed = urlparse.urlsplit(url) + (host, port, ssl_server_name) = self._lookup(parsed.netloc) + print( + f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr + ) + return self.poolmanager.connection_from_host( + host, + port=port, + scheme="https", + pool_kwargs={"server_hostname": ssl_server_name}, + ) + @staticmethod - def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]: - if s[-1] == "]": + def _lookup(server_name: str) -> Tuple[str, int, str]: + """ + Do an SRV lookup on a server name and return the host:port to connect to + Given the server_name (after any .well-known lookup), return the host, port and + the ssl server name + """ + if server_name[-1] == "]": # ipv6 literal (with no port) - return s, 8448 + return server_name, 8448, server_name - if ":" in s: - out = s.rsplit(":", 1) + if ":" in server_name: + # explicit port + out = server_name.rsplit(":", 1) try: port = int(out[1]) except ValueError: - raise ValueError("Invalid host:port '%s'" % s) - return out[0], port - - # try a .well-known lookup - if not skip_well_known: - well_known = MatrixConnectionAdapter.get_well_known(s) - if well_known: - return MatrixConnectionAdapter.lookup(well_known, skip_well_known=True) + raise ValueError("Invalid host:port '%s'" % (server_name,)) + return out[0], port, out[0] try: - srv = srvlookup.lookup("matrix", "tcp", s)[0] - return srv.host, srv.port + srv = srvlookup.lookup("matrix", "tcp", server_name)[0] + print( + f"SRV lookup on _matrix._tcp.{server_name} gave {srv}", + file=sys.stderr, + ) + return srv.host, srv.port, server_name except Exception: - return s, 8448 + return server_name, 8448, server_name @staticmethod - def get_well_known(server_name: str) -> Optional[str]: - uri = "https://%s/.well-known/matrix/server" % (server_name,) - print("fetching %s" % (uri,), file=sys.stderr) + def _get_well_known(server_name: str) -> Optional[str]: + if ":" in server_name: + # explicit port, or ipv6 literal. Either way, no .well-known + return None + + # TODO: check for ipv4 literals + + uri = f"https://{server_name}/.well-known/matrix/server" + print(f"fetching {uri}", file=sys.stderr) try: resp = requests.get(uri) @@ -304,19 +369,6 @@ def get_well_known(server_name: str) -> Optional[str]: print("Invalid response from %s: %s" % (uri, e), file=sys.stderr) return None - def get_connection( - self, url: str, proxies: Optional[Dict[str, str]] = None - ) -> HTTPConnectionPool: - parsed = urlparse.urlparse(url) - - (host, port) = self.lookup(parsed.netloc) - netloc = "%s:%d" % (host, port) - print("Connecting to %s" % (netloc,), file=sys.stderr) - url = urlparse.urlunparse( - ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment) - ) - return super().get_connection(url, proxies) - if __name__ == "__main__": main() diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index ceade65ef90e..a6a586a0b536 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -26,7 +26,11 @@ class PushRules: class FilteredPushRules: def __init__( - self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3664_enabled: bool + self, + push_rules: PushRules, + enabled_map: Dict[str, bool], + msc3664_enabled: bool, + msc1767_enabled: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... @@ -41,6 +45,8 @@ class PushRuleEvaluator: notification_power_levels: Mapping[str, int], related_events_flattened: Mapping[str, Mapping[str, str]], related_event_match_enabled: bool, + room_version_feature_flags: list[str], + msc3931_enabled: bool, ): ... def run( self, diff --git a/synapse/api/errors.py b/synapse/api/errors.py index eb9fc85738d3..76ef12ed3a81 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -711,7 +711,7 @@ def to_synapse_error(self) -> SynapseError: set to the reason code from the HTTP response. Returns: - SynapseError: + The error converted to a SynapseError. """ # try to parse the body as json, to get better errcode/msg, but # default to M_UNKNOWN with the HTTP status as the error text diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index e37acb0f1edf..ac62011c9fb3 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional +from typing import Callable, Dict, List, Optional import attr @@ -51,6 +51,13 @@ class RoomDisposition: UNSTABLE = "unstable" +class PushRuleRoomFlag: + """Enum for listing possible MSC3931 room version feature flags, for push rules""" + + # MSC3932: Room version supports MSC1767 Extensible Events. + EXTENSIBLE_EVENTS = "org.matrix.msc3932.extensible_events" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class RoomVersion: """An object which describes the unique attributes of a room version.""" @@ -91,6 +98,12 @@ class RoomVersion: msc3787_knock_restricted_join_rule: bool # MSC3667: Enforce integer power levels msc3667_int_only_power_levels: bool + # MSC3931: Adds a push rule condition for "room version feature flags", making + # some push rules room version dependent. Note that adding a flag to this list + # is not enough to mark it "supported": the push rule evaluator also needs to + # support the flag. Unknown flags are ignored by the evaluator, making conditions + # fail if used. + msc3931_push_features: List[str] # values from PushRuleRoomFlag class RoomVersions: @@ -111,6 +124,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V2 = RoomVersion( "2", @@ -129,6 +143,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V3 = RoomVersion( "3", @@ -147,6 +162,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V4 = RoomVersion( "4", @@ -165,6 +181,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V5 = RoomVersion( "5", @@ -183,6 +200,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V6 = RoomVersion( "6", @@ -201,6 +219,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) MSC2176 = RoomVersion( "org.matrix.msc2176", @@ -219,6 +238,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V7 = RoomVersion( "7", @@ -237,6 +257,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V8 = RoomVersion( "8", @@ -255,6 +276,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V9 = RoomVersion( "9", @@ -273,6 +295,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) MSC3787 = RoomVersion( "org.matrix.msc3787", @@ -291,6 +314,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=True, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V10 = RoomVersion( "10", @@ -309,6 +333,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=True, msc3667_int_only_power_levels=True, + msc3931_push_features=[], ) MSC2716v4 = RoomVersion( "org.matrix.msc2716v4", @@ -327,6 +352,27 @@ class RoomVersions: msc2716_redactions=True, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], + ) + MSC1767v10 = RoomVersion( + # MSC1767 (Extensible Events) based on room version "10" + "org.matrix.msc1767.10", + RoomDisposition.UNSTABLE, + EventFormatVersions.ROOM_V4_PLUS, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + msc2176_redaction_rules=False, + msc3083_join_rules=True, + msc3375_redaction_rules=True, + msc2403_knocking=True, + msc2716_historical=False, + msc2716_redactions=False, + msc3787_knock_restricted_join_rule=True, + msc3667_int_only_power_levels=True, + msc3931_push_features=[PushRuleRoomFlag.EXTENSIBLE_EVENTS], ) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 41d2732ef96d..a5aa2185a28e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -266,26 +266,18 @@ async def wrapper() -> None: reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics( - bind_addresses: Iterable[str], port: int, enable_legacy_metric_names: bool -) -> None: +def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: """ Start Prometheus metrics server. """ from prometheus_client import start_http_server as start_http_server_prometheus - from synapse.metrics import ( - RegistryProxy, - start_http_server as start_http_server_legacy, - ) + from synapse.metrics import RegistryProxy for host in bind_addresses: logger.info("Starting metrics listener on %s:%d", host, port) - if enable_legacy_metric_names: - start_http_server_legacy(port, addr=host, registry=RegistryProxy) - else: - _set_prometheus_client_use_created_metrics(False) - start_http_server_prometheus(port, addr=host, registry=RegistryProxy) + _set_prometheus_client_use_created_metrics(False) + start_http_server_prometheus(port, addr=host, registry=RegistryProxy) def _set_prometheus_client_use_created_metrics(new_value: bool) -> None: diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1d9aef45c24b..46dc73169643 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -14,14 +14,12 @@ # limitations under the License. import logging import sys -from typing import Dict, List, Optional, Tuple +from typing import Dict, List -from twisted.internet import address from twisted.web.resource import Resource import synapse import synapse.events -from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.api.urls import ( CLIENT_API_PREFIX, FEDERATION_PREFIX, @@ -43,8 +41,6 @@ from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource, OptionsResource -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.http.site import SynapseRequest from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource @@ -70,12 +66,12 @@ versions, voip, ) -from synapse.rest.client._base import client_patterns from synapse.rest.client.account import ThreepidRestServlet, WhoamiRestServlet from synapse.rest.client.devices import DevicesRestServlet from synapse.rest.client.keys import ( KeyChangesServlet, KeyQueryServlet, + KeyUploadServlet, OneTimeKeyServlet, ) from synapse.rest.client.register import ( @@ -132,107 +128,12 @@ from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore -from synapse.types import JsonDict from synapse.util import SYNAPSE_VERSION from synapse.util.httpresourcetree import create_resource_tree logger = logging.getLogger("synapse.app.generic_worker") -class KeyUploadServlet(RestServlet): - """An implementation of the `KeyUploadServlet` that responds to read only - requests, but otherwise proxies through to the master instance. - """ - - PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") - - def __init__(self, hs: HomeServer): - """ - Args: - hs: server - """ - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self.http_client = hs.get_simple_http_client() - self.main_uri = hs.config.worker.worker_main_http_uri - - async def on_POST( - self, request: SynapseRequest, device_id: Optional[str] - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - if device_id is not None: - # passing the device_id here is deprecated; however, we allow it - # for now for compatibility with older clients. - if requester.device_id is not None and device_id != requester.device_id: - logger.warning( - "Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, - device_id, - ) - else: - device_id = requester.device_id - - if device_id is None: - raise SynapseError( - 400, "To upload keys, you must pass device_id when authenticating" - ) - - if body: - # They're actually trying to upload something, proxy to main synapse. - - # Proxy headers from the original request, such as the auth headers - # (in case the access token is there) and the original IP / - # User-Agent of the request. - headers: Dict[bytes, List[bytes]] = { - header: list(request.requestHeaders.getRawHeaders(header, [])) - for header in (b"Authorization", b"User-Agent") - } - # Add the previous hop to the X-Forwarded-For header. - x_forwarded_for = list( - request.requestHeaders.getRawHeaders(b"X-Forwarded-For", []) - ) - # we use request.client here, since we want the previous hop, not the - # original client (as returned by request.getClientAddress()). - if isinstance(request.client, (address.IPv4Address, address.IPv6Address)): - previous_host = request.client.host.encode("ascii") - # If the header exists, add to the comma-separated list of the first - # instance of the header. Otherwise, generate a new header. - if x_forwarded_for: - x_forwarded_for = [x_forwarded_for[0] + b", " + previous_host] - x_forwarded_for.extend(x_forwarded_for[1:]) - else: - x_forwarded_for = [previous_host] - headers[b"X-Forwarded-For"] = x_forwarded_for - - # Replicate the original X-Forwarded-Proto header. Note that - # XForwardedForRequest overrides isSecure() to give us the original protocol - # used by the client, as opposed to the protocol used by our upstream proxy - # - which is what we want here. - headers[b"X-Forwarded-Proto"] = [ - b"https" if request.isSecure() else b"http" - ] - - try: - result = await self.http_client.post_json_get_json( - self.main_uri + request.uri.decode("ascii"), body, headers=headers - ) - except HttpResponseException as e: - raise e.to_synapse_error() from e - except RequestSendFailed as e: - raise SynapseError(502, "Failed to talk to master") from e - - return 200, result - else: - # Just interested in counts. - result = await self.store.count_e2e_one_time_keys(user_id, device_id) - return 200, {"one_time_key_counts": result} - - class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. @@ -419,7 +320,6 @@ def start_listening(self) -> None: _base.listen_metrics( listener.bind_addresses, listener.port, - enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics, ) else: logger.warning("Unsupported listener type: %s", listener.type) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 4f4fee4782f2..b9be558c7ea0 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -265,7 +265,6 @@ def start_listening(self) -> None: _base.listen_metrics( listener.bind_addresses, listener.port, - enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics, ) else: # this shouldn't happen, as the listener type should have been checked diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 05382718f3d4..5c8a72c56f50 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -32,9 +32,9 @@ logger = logging.getLogger(__name__) -# Type for the `device_one_time_key_counts` field in an appservice transaction +# Type for the `device_one_time_keys_count` field in an appservice transaction # user ID -> {device ID -> {algorithm -> count}} -TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]] +TransactionOneTimeKeysCount = Dict[str, Dict[str, Dict[str, int]]] # Type for the `device_unused_fallback_key_types` field in an appservice transaction # user ID -> {device ID -> [algorithm]} @@ -381,7 +381,7 @@ def __init__( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, ): @@ -390,7 +390,7 @@ def __init__( self.events = events self.ephemeral = ephemeral self.to_device_messages = to_device_messages - self.one_time_key_counts = one_time_key_counts + self.one_time_keys_count = one_time_keys_count self.unused_fallback_keys = unused_fallback_keys self.device_list_summary = device_list_summary @@ -407,7 +407,7 @@ async def send(self, as_api: "ApplicationServiceApi") -> bool: events=self.events, ephemeral=self.ephemeral, to_device_messages=self.to_device_messages, - one_time_key_counts=self.one_time_key_counts, + one_time_keys_count=self.one_time_keys_count, unused_fallback_keys=self.unused_fallback_keys, device_list_summary=self.device_list_summary, txn_id=self.id, diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 60774b240d9f..edafd433cda3 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -23,7 +23,7 @@ from synapse.api.errors import CodeMessageException from synapse.appservice import ( ApplicationService, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.events import EventBase @@ -262,7 +262,7 @@ async def push_bulk( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, txn_id: Optional[int] = None, @@ -310,10 +310,13 @@ async def push_bulk( # TODO: Update to stable prefixes once MSC3202 completes FCP merge if service.msc3202_transaction_extensions: - if one_time_key_counts: + if one_time_keys_count: body[ "org.matrix.msc3202.device_one_time_key_counts" - ] = one_time_key_counts + ] = one_time_keys_count + body[ + "org.matrix.msc3202.device_one_time_keys_count" + ] = one_time_keys_count if unused_fallback_keys: body[ "org.matrix.msc3202.device_unused_fallback_key_types" diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index e52c3d2ba6b2..e11a312e6bc8 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -64,7 +64,7 @@ from synapse.appservice import ( ApplicationService, ApplicationServiceState, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.appservice.api import ApplicationServiceApi @@ -258,7 +258,7 @@ async def _send_request(self, service: ApplicationService) -> None: ): return - one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None + one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None if ( @@ -269,7 +269,7 @@ async def _send_request(self, service: ApplicationService) -> None: # for the users which are mentioned in this transaction, # as well as the appservice's sender. ( - one_time_key_counts, + one_time_keys_count, unused_fallback_keys, ) = await self._compute_msc3202_otk_counts_and_fallback_keys( service, events, ephemeral, to_device_messages_to_send @@ -281,7 +281,7 @@ async def _send_request(self, service: ApplicationService) -> None: events, ephemeral, to_device_messages_to_send, - one_time_key_counts, + one_time_keys_count, unused_fallback_keys, device_list_summary, ) @@ -296,7 +296,7 @@ async def _compute_msc3202_otk_counts_and_fallback_keys( events: Iterable[EventBase], ephemerals: Iterable[JsonDict], to_device_messages: Iterable[JsonDict], - ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: + ) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]: """ Given a list of the events, ephemeral messages and to-device messages, - first computes a list of application services users that may have @@ -367,7 +367,7 @@ async def send( events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, - one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, + one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: @@ -380,7 +380,7 @@ async def send( events: The persistent events to include in the transaction. ephemeral: The ephemeral events to include in the transaction. to_device_messages: The to-device messages to include in the transaction. - one_time_key_counts: Counts of remaining one-time keys for relevant + one_time_keys_count: Counts of remaining one-time keys for relevant appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. @@ -397,7 +397,7 @@ async def send( events=events, ephemeral=ephemeral or [], to_device_messages=to_device_messages or [], - one_time_key_counts=one_time_key_counts or {}, + one_time_keys_count=one_time_keys_count or {}, unused_fallback_keys=unused_fallback_keys or {}, device_list_summary=device_list_summary or DeviceListUpdates(), ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c056f8176423..13f212880b0a 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -16,6 +16,7 @@ import attr +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.config._base import Config from synapse.types import JsonDict @@ -57,9 +58,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) - # MSC3030 (Jump to date API endpoint) - self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) - # MSC2409 (this setting only relates to optionally sending to-device messages). # Presence, typing and read receipt EDUs are already sent to application services that # have opted in to receive them. If enabled, this adds to-device messages to that list. @@ -135,3 +133,10 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: # MSC3912: Relation-based redactions. self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) + + # MSC1767 and friends: Extensible Events + self.msc1767_enabled: bool = experimental.get("msc1767_enabled", False) + if self.msc1767_enabled: + # Enable room version (and thus applicable push rules from MSC3931/3932) + version_id = RoomVersions.MSC1767v10.identifier + KNOWN_ROOM_VERSIONS[version_id] = RoomVersions.MSC1767v10 diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 94d115041526..5468b963a2c1 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -317,10 +317,9 @@ def setup_logging( Set up the logging subsystem. Args: - config (LoggingConfig | synapse.config.worker.WorkerConfig): - configuration data + config: configuration data - use_worker_options (bool): True to use the 'worker_log_config' option + use_worker_options: True to use the 'worker_log_config' option instead of 'log_config'. logBeginner: The Twisted logBeginner to use. diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 6034a0346e58..8c1c9bd12d45 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -43,8 +43,6 @@ class MetricsConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.enable_metrics = config.get("enable_metrics", False) - self.enable_legacy_metrics = config.get("enable_legacy_metrics", False) - self.report_stats = config.get("report_stats", None) self.report_stats_endpoint = config.get( "report_stats_endpoint", "https://matrix.org/report-usage-stats/push" diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 1ed001e10553..5c13fe428a70 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -150,8 +150,5 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.rc_third_party_invite = RatelimitSettings( config.get("rc_third_party_invite", {}), - defaults={ - "per_second": self.rc_message.per_second, - "burst_count": self.rc_message.burst_count, - }, + defaults={"per_second": 0.0025, "burst_count": 5}, ) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 88b3168cbc71..2580660b6c27 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -29,20 +29,6 @@ ) from .server import DIRECT_TCP_ERROR, ListenerConfig, parse_listener_def -_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR = """ -The send_federation config option must be disabled in the main -synapse process before they can be run in a separate worker. - -Please add ``send_federation: false`` to the main config -""" - -_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR = """ -The start_pushers config option must be disabled in the main -synapse process before they can be run in a separate worker. - -Please add ``start_pushers: false`` to the main config -""" - _DEPRECATED_WORKER_DUTY_OPTION_USED = """ The '%s' configuration option is deprecated and will be removed in a future Synapse version. Please use ``%s: name_of_worker`` instead. @@ -162,7 +148,13 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.worker_name = config.get("worker_name", self.worker_app) self.instance_name = self.worker_name or "master" + # FIXME: Remove this check after a suitable amount of time. self.worker_main_http_uri = config.get("worker_main_http_uri", None) + if self.worker_main_http_uri is not None: + logger.warning( + "The config option worker_main_http_uri is unused since Synapse 1.73. " + "It can be safely removed from your configuration." + ) # This option is really only here to support `--manhole` command line # argument. @@ -176,40 +168,12 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: ) ) - # Handle federation sender configuration. - # - # There are two ways of configuring which instances handle federation - # sending: - # 1. The old way where "send_federation" is set to false and running a - # `synapse.app.federation_sender` worker app. - # 2. Specifying the workers sending federation in - # `federation_sender_instances`. - # - - send_federation = config.get("send_federation", True) - - federation_sender_instances = config.get("federation_sender_instances") - if federation_sender_instances is None: - # Default to an empty list, which means "another, unknown, worker is - # responsible for it". - federation_sender_instances = [] - - # If no federation sender instances are set we check if - # `send_federation` is set, which means use master - if send_federation: - federation_sender_instances = ["master"] - - if self.worker_app == "synapse.app.federation_sender": - if send_federation: - # If we're running federation senders, and not using - # `federation_sender_instances`, then we should have - # explicitly set `send_federation` to false. - raise ConfigError( - _FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR - ) - - federation_sender_instances = [self.worker_name] - + federation_sender_instances = self._worker_names_performing_this_duty( + config, + "send_federation", + "synapse.app.federation_sender", + "federation_sender_instances", + ) self.send_federation = self.instance_name in federation_sender_instances self.federation_shard_config = ShardedWorkerHandlingConfig( federation_sender_instances @@ -276,27 +240,12 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: ) # Handle sharded push - start_pushers = config.get("start_pushers", True) - pusher_instances = config.get("pusher_instances") - if pusher_instances is None: - # Default to an empty list, which means "another, unknown, worker is - # responsible for it". - pusher_instances = [] - - # If no pushers instances are set we check if `start_pushers` is - # set, which means use master - if start_pushers: - pusher_instances = ["master"] - - if self.worker_app == "synapse.app.pusher": - if start_pushers: - # If we're running pushers, and not using - # `pusher_instances`, then we should have explicitly set - # `start_pushers` to false. - raise ConfigError(_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR) - - pusher_instances = [self.instance_name] - + pusher_instances = self._worker_names_performing_this_duty( + config, + "start_pushers", + "synapse.app.pusher", + "pusher_instances", + ) self.start_pushers = self.instance_name in pusher_instances self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) @@ -419,6 +368,64 @@ def _should_this_worker_perform_duty( # (By this point, these are either the same value or only one is not None.) return bool(new_option_should_run_here or legacy_option_should_run_here) + def _worker_names_performing_this_duty( + self, + config: Dict[str, Any], + legacy_option_name: str, + legacy_app_name: str, + modern_instance_list_name: str, + ) -> List[str]: + """ + Retrieves the names of the workers handling a given duty, by either legacy + option or instance list. + + There are two ways of configuring which instances handle a given duty, e.g. + for configuring pushers: + + 1. The old way where "start_pushers" is set to false and running a + `synapse.app.pusher'` worker app. + 2. Specifying the workers sending federation in `pusher_instances`. + + Args: + config: settings read from yaml. + legacy_option_name: the old way of enabling options. e.g. 'start_pushers' + legacy_app_name: The historical app name. e.g. 'synapse.app.pusher' + modern_instance_list_name: the string name of the new instance_list. e.g. + 'pusher_instances' + + Returns: + A list of worker instance names handling the given duty. + """ + + legacy_option = config.get(legacy_option_name, True) + + worker_instances = config.get(modern_instance_list_name) + if worker_instances is None: + # Default to an empty list, which means "another, unknown, worker is + # responsible for it". + worker_instances = [] + + # If no worker instances are set we check if the legacy option + # is set, which means use the main process. + if legacy_option: + worker_instances = ["master"] + + if self.worker_app == legacy_app_name: + if legacy_option: + # If we're using `legacy_app_name`, and not using + # `modern_instance_list_name`, then we should have + # explicitly set `legacy_option_name` to false. + raise ConfigError( + f"The '{legacy_option_name}' config option must be disabled in " + "the main synapse process before they can be run in a separate " + "worker.\n" + f"Please add `{legacy_option_name}: false` to the main config.\n", + ) + + worker_instances = [self.worker_name] + + return worker_instances + def read_arguments(self, args: argparse.Namespace) -> None: # We support a bunch of command line arguments that override options in # the config. A lot of these options have a worker_* prefix when running diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index c88afb298620..ed15f88350a6 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -213,7 +213,7 @@ async def verify_json_for_server( def verify_json_objects_for_server( self, server_and_json: Iterable[Tuple[str, dict, int]] - ) -> List[defer.Deferred]: + ) -> List["defer.Deferred[None]"]: """Bulk verifies signatures of json objects, bulk fetching keys as necessary. @@ -226,10 +226,9 @@ def verify_json_objects_for_server( valid. Returns: - List: for each input triplet, a deferred indicating success - or failure to verify each json object's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. + For each input triplet, a deferred indicating success or failure to + verify each json object's signature for the given server_name. The + deferreds run their callbacks in the sentinel logcontext. """ return [ run_in_background( @@ -858,7 +857,7 @@ async def get_server_verify_key_v2_direct( response = await self.client.get_json( destination=server_name, path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id), + + urllib.parse.quote(requested_key_id, safe=""), ignore_backoff=True, # we only give the remote server 10s to respond. It should be an # easy request to handle, so if it doesn't reply within 10s, it's diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 030c3ca408c0..8aca9a3ab9e9 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -597,8 +597,7 @@ def _event_type_from_format_version( format_version: The event format version Returns: - type: A type that can be initialized as per the initializer of - `FrozenEvent` + A type that can be initialized as per the initializer of `FrozenEvent` """ if format_version == EventFormatVersions.ROOM_V1_V2: diff --git a/synapse/events/builder.py b/synapse/events/builder.py index e2ee10dd3ddc..d62906043f3f 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -128,6 +128,7 @@ async def build( state_filter=StateFilter.from_types( auth_types_for_event(self.room_version, self) ), + await_full_state=False, ) auth_event_ids = self._event_auth_handler.compute_auth_events( self, state_ids diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c4c0bc7315b4..8bccc9c60d5f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1691,9 +1691,19 @@ async def _timestamp_to_event_from_destination( # to return events on *both* sides of the timestamp to # help reconcile the gap faster. _timestamp_to_event_from_destination, + # Since this endpoint is new, we should try other servers before giving up. + # We can safely remove this in a year (remove after 2023-11-16). + failover_on_unknown_endpoint=True, ) return timestamp_to_event_response - except SynapseError: + except SynapseError as e: + logger.warn( + "timestamp_to_event(room_id=%s, timestamp=%s, direction=%s): encountered error when trying to fetch from destinations: %s", + room_id, + timestamp, + direction, + e, + ) return None async def _timestamp_to_event_from_destination( diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 431e68995b22..05d865fe990f 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -441,7 +441,23 @@ async def handle_event(event: EventBase) -> None: # If there are no prev event IDs then the state is empty # and so no remote servers in the room destinations = set() - else: + + if destinations is None: + # During partial join we use the set of servers that we got + # when beginning the join. It's still possible that we send + # events to servers that left the room in the meantime, but + # we consider that an acceptable risk since it is only our own + # events that we leak and not other server's ones. + partial_state_destinations = ( + await self.store.get_partial_state_servers_at_join( + event.room_id + ) + ) + + if len(partial_state_destinations) > 0: + destinations = partial_state_destinations + + if destinations is None: # We check the external cache for the destinations, which is # stored per state group. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 084c45a95ca1..5af2784f1e98 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -35,7 +35,7 @@ from synapse.logging.opentracing import SynapseTags, set_tag from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import ReadReceipt +from synapse.types import JsonDict, ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.visibility import filter_events_for_server @@ -136,8 +136,11 @@ def __init__( # destination self._pending_presence: Dict[str, UserPresenceState] = {} - # room_id -> receipt_type -> user_id -> receipt_dict - self._pending_rrs: Dict[str, Dict[str, Dict[str, dict]]] = {} + # List of room_id -> receipt_type -> user_id -> receipt_dict, + # + # Each receipt can only have a single receipt per + # (room ID, receipt type, user ID, thread ID) tuple. + self._pending_receipt_edus: List[Dict[str, Dict[str, Dict[str, dict]]]] = [] self._rrs_pending_flush = False # stream_id of last successfully sent to-device message. @@ -202,17 +205,53 @@ def queue_read_receipt(self, receipt: ReadReceipt) -> None: Args: receipt: receipt to be queued """ - self._pending_rrs.setdefault(receipt.room_id, {}).setdefault( - receipt.receipt_type, {} - )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data} + serialized_receipt: JsonDict = { + "event_ids": receipt.event_ids, + "data": receipt.data, + } + if receipt.thread_id is not None: + serialized_receipt["data"]["thread_id"] = receipt.thread_id + + # Find which EDU to add this receipt to. There's three situations depending + # on the (room ID, receipt type, user, thread ID) tuple: + # + # 1. If it fully matches, clobber the information. + # 2. If it is missing, add the information. + # 3. If the subset tuple of (room ID, receipt type, user) matches, check + # the next EDU (or add a new EDU). + for edu in self._pending_receipt_edus: + receipt_content = edu.setdefault(receipt.room_id, {}).setdefault( + receipt.receipt_type, {} + ) + # If this room ID, receipt type, user ID is not in this EDU, OR if + # the full tuple matches, use the current EDU. + if ( + receipt.user_id not in receipt_content + or receipt_content[receipt.user_id].get("thread_id") + == receipt.thread_id + ): + receipt_content[receipt.user_id] = serialized_receipt + break + + # If no matching EDU was found, create a new one. + else: + self._pending_receipt_edus.append( + { + receipt.room_id: { + receipt.receipt_type: {receipt.user_id: serialized_receipt} + } + } + ) def flush_read_receipts_for_room(self, room_id: str) -> None: - # if we don't have any read-receipts for this room, it may be that we've already - # sent them out, so we don't need to flush. - if room_id not in self._pending_rrs: - return - self._rrs_pending_flush = True - self.attempt_new_transaction() + # If there are any pending receipts for this room then force-flush them + # in a new transaction. + for edu in self._pending_receipt_edus: + if room_id in edu: + self._rrs_pending_flush = True + self.attempt_new_transaction() + # No use in checking remaining EDUs if the room was found. + break def send_keyed_edu(self, edu: Edu, key: Hashable) -> None: self._pending_edus_keyed[(edu.edu_type, key)] = edu @@ -351,7 +390,7 @@ async def _transaction_transmission_loop(self) -> None: self._pending_edus = [] self._pending_edus_keyed = {} self._pending_presence = {} - self._pending_rrs = {} + self._pending_receipt_edus = [] self._start_catching_up() except FederationDeniedError as e: @@ -505,6 +544,7 @@ async def _catch_up_transmission_loop(self) -> None: new_pdus = await filter_events_for_server( self._storage_controllers, self._destination, + self._server_name, new_pdus, redact=False, ) @@ -542,22 +582,27 @@ async def _catch_up_transmission_loop(self) -> None: self._destination, last_successful_stream_ordering ) - def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: - if not self._pending_rrs: + def _get_receipt_edus(self, force_flush: bool, limit: int) -> Iterable[Edu]: + if not self._pending_receipt_edus: return if not force_flush and not self._rrs_pending_flush: # not yet time for this lot return - edu = Edu( - origin=self._server_name, - destination=self._destination, - edu_type=EduTypes.RECEIPT, - content=self._pending_rrs, - ) - self._pending_rrs = {} - self._rrs_pending_flush = False - yield edu + # Send at most limit EDUs for receipts. + for content in self._pending_receipt_edus[:limit]: + yield Edu( + origin=self._server_name, + destination=self._destination, + edu_type=EduTypes.RECEIPT, + content=content, + ) + self._pending_receipt_edus = self._pending_receipt_edus[limit:] + + # If there are still pending read-receipts, don't reset the pending flush + # flag. + if not self._pending_receipt_edus: + self._rrs_pending_flush = False def _pop_pending_edus(self, limit: int) -> List[Edu]: pending_edus = self._pending_edus @@ -644,27 +689,61 @@ class _TransactionQueueManager: async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: # First we calculate the EDUs we want to send, if any. - # We start by fetching device related EDUs, i.e device updates and to - # device messages. We have to keep 2 free slots for presence and rr_edus. - device_edu_limit = MAX_EDUS_PER_TRANSACTION - 2 + # There's a maximum number of EDUs that can be sent with a transaction, + # generally device updates and to-device messages get priority, but we + # want to ensure that there's room for some other EDUs as well. + # + # This is done by: + # + # * Add a presence EDU, if one exists. + # * Add up-to a small limit of read receipt EDUs. + # * Add to-device EDUs, but leave some space for device list updates. + # * Add device list updates EDUs. + # * If there's any remaining room, add other EDUs. + pending_edus = [] + + # Add presence EDU. + if self.queue._pending_presence: + pending_edus.append( + Edu( + origin=self.queue._server_name, + destination=self.queue._destination, + edu_type=EduTypes.PRESENCE, + content={ + "push": [ + format_user_presence_state( + presence, self.queue._clock.time_msec() + ) + for presence in self.queue._pending_presence.values() + ] + }, + ) + ) + self.queue._pending_presence = {} - # We prioritize to-device messages so that existing encryption channels + # Add read receipt EDUs. + pending_edus.extend(self.queue._get_receipt_edus(force_flush=False, limit=5)) + edu_limit = MAX_EDUS_PER_TRANSACTION - len(pending_edus) + + # Next, prioritize to-device messages so that existing encryption channels # work. We also keep a few slots spare (by reducing the limit) so that # we can still trickle out some device list updates. ( to_device_edus, device_stream_id, - ) = await self.queue._get_to_device_message_edus(device_edu_limit - 10) + ) = await self.queue._get_to_device_message_edus(edu_limit - 10) if to_device_edus: self._device_stream_id = device_stream_id else: self.queue._last_device_stream_id = device_stream_id - device_edu_limit -= len(to_device_edus) + pending_edus.extend(to_device_edus) + edu_limit -= len(to_device_edus) + # Add device list update EDUs. device_update_edus, dev_list_id = await self.queue._get_device_update_edus( - device_edu_limit + edu_limit ) if device_update_edus: @@ -672,40 +751,17 @@ async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: else: self.queue._last_device_list_stream_id = dev_list_id - pending_edus = device_update_edus + to_device_edus - - # Now add the read receipt EDU. - pending_edus.extend(self.queue._get_rr_edus(force_flush=False)) - - # And presence EDU. - if self.queue._pending_presence: - pending_edus.append( - Edu( - origin=self.queue._server_name, - destination=self.queue._destination, - edu_type=EduTypes.PRESENCE, - content={ - "push": [ - format_user_presence_state( - presence, self.queue._clock.time_msec() - ) - for presence in self.queue._pending_presence.values() - ] - }, - ) - ) - self.queue._pending_presence = {} + pending_edus.extend(device_update_edus) + edu_limit -= len(device_update_edus) # Finally add any other types of EDUs if there is room. - pending_edus.extend( - self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) - ) - while ( - len(pending_edus) < MAX_EDUS_PER_TRANSACTION - and self.queue._pending_edus_keyed - ): + other_edus = self.queue._pop_pending_edus(edu_limit) + pending_edus.extend(other_edus) + edu_limit -= len(other_edus) + while edu_limit > 0 and self.queue._pending_edus_keyed: _, val = self.queue._pending_edus_keyed.popitem() pending_edus.append(val) + edu_limit -= 1 # Now we look for any PDUs to send, by getting up to 50 PDUs from the # queue @@ -716,8 +772,10 @@ async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: # if we've decided to send a transaction anyway, and we have room, we # may as well send any pending RRs - if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: - pending_edus.extend(self.queue._get_rr_edus(force_flush=True)) + if edu_limit: + pending_edus.extend( + self.queue._get_receipt_edus(force_flush=True, limit=edu_limit) + ) if self._pdus: self._last_stream_ordering = self._pdus[ diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index cd39d4d1113a..77f1f39cacb1 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -185,9 +185,8 @@ async def timestamp_to_event( Raises: Various exceptions when the request fails """ - path = _create_path( - FEDERATION_UNSTABLE_PREFIX, - "/org.matrix.msc3030/timestamp_to_event/%s", + path = _create_v1_path( + "/timestamp_to_event/%s", room_id, ) @@ -280,12 +279,11 @@ async def make_membership_event( Note that this does not append any events to any graphs. Args: - destination (str): address of remote homeserver - room_id (str): room to join/leave - user_id (str): user to be joined/left - membership (str): one of join/leave - params (dict[str, str|Iterable[str]]): Query parameters to include in the - request. + destination: address of remote homeserver + room_id: room to join/leave + user_id: user to be joined/left + membership: one of join/leave + params: Query parameters to include in the request. Returns: Succeeds when we get a 2xx HTTP response. The result diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 50623cd38513..2725f53cf6d9 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -25,7 +25,6 @@ from synapse.federation.transport.server.federation import ( FEDERATION_SERVLET_CLASSES, FederationAccountStatusServlet, - FederationTimestampLookupServlet, ) from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( @@ -291,13 +290,6 @@ def register_servlets( ) for servletclass in SERVLET_GROUPS[servlet_group]: - # Only allow the `/timestamp_to_event` servlet if msc3030 is enabled - if ( - servletclass == FederationTimestampLookupServlet - and not hs.config.experimental.msc3030_enabled - ): - continue - # Only allow the `/account_status` servlet if msc3720 is enabled if ( servletclass == FederationAccountStatusServlet diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 1db8009d6ccf..cdaf0d5de782 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -224,10 +224,10 @@ class BaseFederationServlet: With arguments: - origin (unicode|None): The authenticated server_name of the calling server, + origin (str|None): The authenticated server_name of the calling server, unless REQUIRE_AUTH is set to False and authentication failed. - content (unicode|None): decoded json body of the request. None if the + content (str|None): decoded json body of the request. None if the request was a GET. query (dict[bytes, list[bytes]]): Query params from the request. url-decoded diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 205fd16daa98..53e77b4bb62b 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -218,14 +218,13 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet): `dir` can be `f` or `b` to indicate forwards and backwards in time from the given timestamp. - GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/?ts=&dir= + GET /_matrix/federation/v1/timestamp_to_event/?ts=&dir= { "event_id": ... } """ PATH = "/timestamp_to_event/(?P[^/]*)/?" - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3030" async def on_GET( self, diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 816e1a6d79c8..d74d135c0c50 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError +from synapse.handlers.device import DeviceHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Codes, Requester, UserID, create_requester @@ -76,6 +77,9 @@ async def deactivate_account( True if identity server supports removing threepids, otherwise False. """ + # This can only be called on the main process. + assert isinstance(self._device_handler, DeviceHandler) + # Check if this user can be deactivated if not await self._third_party_rules.check_can_deactivate_user( user_id, by_admin diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c597639a7ff5..b1e55e1b9e4a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -65,6 +65,8 @@ class DeviceWorkerHandler: + device_list_updater: "DeviceListWorkerUpdater" + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs @@ -76,6 +78,8 @@ def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self._msc3852_enabled = hs.config.experimental.msc3852_enabled + self.device_list_updater = DeviceListWorkerUpdater(hs) + @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ @@ -99,6 +103,19 @@ async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: log_kv(device_map) return devices + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: """Retrieve the given device @@ -127,7 +144,7 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict: @cancellable async def get_device_changes_in_shared_rooms( self, user_id: str, room_ids: Collection[str], from_token: StreamToken - ) -> Collection[str]: + ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. """ @@ -320,6 +337,8 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None: class DeviceHandler(DeviceWorkerHandler): + device_list_updater: "DeviceListUpdater" + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -606,19 +625,6 @@ async def store_dehydrated_device( await self.delete_devices(user_id, [old_device_id]) return device_id - async def get_dehydrated_device( - self, user_id: str - ) -> Optional[Tuple[str, JsonDict]]: - """Retrieve the information for a dehydrated device. - - Args: - user_id: the user whose dehydrated device we are looking for - Returns: - a tuple whose first item is the device ID, and the second item is - the dehydrated device information - """ - return await self.store.get_dehydrated_device(user_id) - async def rehydrate_device( self, user_id: str, access_token: str, device_id: str ) -> dict: @@ -682,13 +688,33 @@ async def _handle_new_device_update_async(self) -> None: hosts_already_sent_to: Set[str] = set() try: + stream_id, room_id = await self.store.get_device_change_last_converted_pos() + while True: self._handle_new_device_update_new_data = False - rows = await self.store.get_uncoverted_outbound_room_pokes() + max_stream_id = self.store.get_device_stream_token() + rows = await self.store.get_uncoverted_outbound_room_pokes( + stream_id, room_id + ) if not rows: # If the DB returned nothing then there is nothing left to # do, *unless* a new device list update happened during the # DB query. + + # Advance `(stream_id, room_id)`. + # `max_stream_id` comes from *before* the query for unconverted + # rows, which means that any unconverted rows must have a larger + # stream ID. + if max_stream_id > stream_id: + stream_id, room_id = max_stream_id, "" + await self.store.set_device_change_last_converted_pos( + stream_id, room_id + ) + else: + assert max_stream_id == stream_id + # Avoid moving `room_id` backwards. + pass + if self._handle_new_device_update_new_data: continue else: @@ -718,7 +744,6 @@ async def _handle_new_device_update_async(self) -> None: user_id=user_id, device_id=device_id, room_id=room_id, - stream_id=stream_id, hosts=hosts, context=opentracing_context, ) @@ -752,6 +777,12 @@ async def _handle_new_device_update_async(self) -> None: hosts_already_sent_to.update(hosts) current_stream_id = stream_id + # Advance `(stream_id, room_id)`. + _, _, room_id, stream_id, _ = rows[-1] + await self.store.set_device_change_last_converted_pos( + stream_id, room_id + ) + finally: self._handle_new_device_update_is_processing = False @@ -834,7 +865,6 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None: user_id=user_id, device_id=device_id, room_id=room_id, - stream_id=None, hosts=potentially_changed_hosts, context=None, ) @@ -858,7 +888,36 @@ def _update_device_from_client_ips( ) -class DeviceListUpdater: +class DeviceListWorkerUpdater: + "Handles incoming device list updates from federation and contacts the main process over replication" + + def __init__(self, hs: "HomeServer"): + from synapse.replication.http.devices import ( + ReplicationUserDevicesResyncRestServlet, + ) + + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) + ) + + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[JsonDict]: + """Fetches all devices for a user and updates the device cache with them. + + Args: + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale + if the attempt to resync failed. + Returns: + A dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + """ + return await self._user_device_resync_client(user_id=user_id) + + +class DeviceListUpdater(DeviceListWorkerUpdater): "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index a9912c467dbb..5fe102e2f2f3 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -27,9 +27,9 @@ from synapse.api.constants import EduTypes from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( JsonDict, UserID, @@ -56,27 +56,23 @@ def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.clock = hs.get_clock() - self._edu_updater = SigningKeyEduUpdater(hs, self) - federation_registry = hs.get_federation_registry() - self._is_master = hs.config.worker.worker_app is None - if not self._is_master: - self._user_device_resync_client = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) - else: + is_master = hs.config.worker.worker_app is None + if is_master: + edu_updater = SigningKeyEduUpdater(hs) + # Only register this edu handler on master as it requires writing # device updates to the db federation_registry.register_edu_handler( EduTypes.SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # also handle the unstable version # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # doesn't really work as part of the generic query API, because the @@ -319,14 +315,13 @@ async def _query_devices_for_destination( # probably be tracking their device lists. However, we haven't # done an initial sync on the device list so we do it now. try: - if self._is_master: - resync_results = await self.device_handler.device_list_updater.user_device_resync( + resync_results = ( + await self.device_handler.device_list_updater.user_device_resync( user_id ) - else: - resync_results = await self._user_device_resync_client( - user_id=user_id - ) + ) + if resync_results is None: + raise ValueError("Device resync failed") # Add the device keys to the results. user_devices = resync_results["devices"] @@ -605,6 +600,8 @@ async def claim_client_keys(destination: str) -> None: async def upload_keys_for_user( self, user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) time_now = self.clock.time_msec() @@ -732,6 +729,8 @@ async def upload_signing_keys_for_user( user_id: the user uploading the keys keys: the signing keys """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) # if a master key is uploaded, then check it. Otherwise, load the # stored master key, to check signatures on other keys @@ -823,6 +822,9 @@ async def upload_signatures_for_device_keys( Raises: SynapseError: if the signatures dict is not valid. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + failures = {} # signatures to be stored. Each item will be a SignatureListItem @@ -870,7 +872,7 @@ async def _process_self_signatures( - signatures of the user's master key by the user's devices. Args: - user_id (string): the user uploading the keys + user_id: the user uploading the keys signatures (dict[string, dict]): map of devices to signed keys Returns: @@ -1200,6 +1202,9 @@ async def _retrieve_cross_signing_keys_for_remote_user( A tuple of the retrieved key content, the key's ID and the matching VerifyKey. If the key cannot be retrieved, all values in the tuple will instead be None. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + try: remote_result = await self.federation.query_user_devices( user.domain, user.to_string() @@ -1396,11 +1401,14 @@ class SignatureListItem: class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() - self.e2e_keys_handler = e2e_keys_handler + + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler self._remote_edu_linearizer = Linearizer(name="remote_signing_key") @@ -1445,9 +1453,6 @@ async def _handle_signing_key_updates(self, user_id: str) -> None: user_id: the user whose updates we are processing """ - device_handler = self.e2e_keys_handler.device_handler - device_list_updater = device_handler.device_list_updater - async with self._remote_edu_linearizer.queue(user_id): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: @@ -1459,13 +1464,11 @@ async def _handle_signing_key_updates(self, user_id: str) -> None: logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = ( - await device_list_updater.process_cross_signing_key_update( - user_id, - master_key, - self_signing_key, - ) + new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + new_device_ids - await device_handler.notify_device_update(user_id, device_ids) + await self._device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 28dc08c22a36..83f53ceb8891 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -377,8 +377,9 @@ async def delete_version(self, user_id: str, version: Optional[str] = None) -> N """Deletes a given version of the user's e2e_room_keys backup Args: - user_id(str): the user whose current backup version we're deleting - version(str): the version id of the backup being deleted + user_id: the user whose current backup version we're deleting + version: Optional. the version ID of the backup version we're deleting + If missing, we delete the current backup version info. Raises: NotFoundError: if this backup version doesn't exist """ diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 3bbad0271bcc..f91dbbecb79c 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -45,6 +45,7 @@ class EventAuthHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self._store = hs.get_datastores().main + self._state_storage_controller = hs.get_storage_controllers().state self._server_name = hs.hostname async def check_auth_rules_from_context( @@ -179,17 +180,22 @@ async def assert_host_in_room( this function may return an incorrect result as we are not able to fully track server membership in a room without full state. """ - if not allow_partial_state_rooms and await self._store.is_partial_state_room( - room_id - ): - raise AuthError( - 403, - "Unable to authorise you right now; room is partial-stated here.", - errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE, - ) - - if not await self.is_host_in_room(room_id, host): - raise AuthError(403, "Host not in room.") + if await self._store.is_partial_state_room(room_id): + if allow_partial_state_rooms: + current_hosts = await self._state_storage_controller.get_current_hosts_in_room_or_partial_state_approximation( + room_id + ) + if host not in current_hosts: + raise AuthError(403, "Host not in room (partial-state approx).") + else: + raise AuthError( + 403, + "Unable to authorise you right now; room is partial-stated here.", + errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE, + ) + else: + if not await self.is_host_in_room(room_id, host): + raise AuthError(403, "Host not in room.") async def check_restricted_join_rules( self, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5fc3b8bc8c3d..d92582fd5c77 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -379,6 +379,7 @@ async def _maybe_backfill_inner( filtered_extremities = await filter_events_for_server( self._storage_controllers, self.server_name, + self.server_name, events_to_check, redact=False, check_history_visibility_only=True, @@ -1231,7 +1232,9 @@ async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: async def on_backfill_request( self, origin: str, room_id: str, pdu_list: List[str], limit: int ) -> List[EventBase]: - await self._event_auth_handler.assert_host_in_room(room_id, origin) + # We allow partially joined rooms since in this case we are filtering out + # non-local events in `filter_events_for_server`. + await self._event_auth_handler.assert_host_in_room(room_id, origin, True) # Synapse asks for 100 events per backfill request. Do not allow more. limit = min(limit, 100) @@ -1252,7 +1255,7 @@ async def on_backfill_request( ) events = await filter_events_for_server( - self._storage_controllers, origin, events + self._storage_controllers, origin, self.server_name, events ) return events @@ -1283,7 +1286,7 @@ async def get_persisted_pdu( await self._event_auth_handler.assert_host_in_room(event.room_id, origin) events = await filter_events_for_server( - self._storage_controllers, origin, [event] + self._storage_controllers, origin, self.server_name, [event] ) event = events[0] return event @@ -1296,7 +1299,9 @@ async def on_get_missing_events( latest_events: List[str], limit: int, ) -> List[EventBase]: - await self._event_auth_handler.assert_host_in_room(room_id, origin) + # We allow partially joined rooms since in this case we are filtering out + # non-local events in `filter_events_for_server`. + await self._event_auth_handler.assert_host_in_room(room_id, origin, True) # Only allow up to 20 events to be retrieved per request. limit = min(limit, 20) @@ -1309,7 +1314,7 @@ async def on_get_missing_events( ) missing_events = await filter_events_for_server( - self._storage_controllers, origin, missing_events + self._storage_controllers, origin, self.server_name, missing_events ) return missing_events @@ -1596,8 +1601,8 @@ async def get_room_complexity( Fetch the complexity of a remote room over federation. Args: - remote_room_hosts (list[str]): The remote servers to ask. - room_id (str): The room ID to ask about. + remote_room_hosts: The remote servers to ask. + room_id: The room ID to ask about. Returns: Dict contains the complexity diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 93d09e993961..848e46eb9ba6 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -711,7 +711,7 @@ async def ask_id_server_for_third_party_invite( inviter_display_name: The current display name of the inviter. inviter_avatar_url: The URL of the inviter's avatar. - id_access_token (str): The access token to authenticate to the identity + id_access_token: The access token to authenticate to the identity server with Returns: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ec039f5efa3d..4bcdde0e7590 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1137,11 +1137,13 @@ async def create_new_client_event( ) state_events = await self.store.get_events_as_list(state_event_ids) # Create a StateMap[str] - state_map = {(e.type, e.state_key): e.event_id for e in state_events} + current_state_ids = { + (e.type, e.state_key): e.event_id for e in state_events + } # Actually strip down and only use the necessary auth events auth_event_ids = self._event_auth_handler.compute_auth_events( event=temp_event, - current_state_ids=state_map, + current_state_ids=current_state_ids, for_verification=False, ) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 867973dcca4a..03de6a4ba637 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -787,7 +787,7 @@ async def _fetch_userinfo(self, token: Token) -> UserInfo: Must include an ``access_token`` field. Returns: - UserInfo: an object representing the user. + an object representing the user. """ logger.debug("Using the OAuth2 access_token to request userinfo") metadata = await self.load_metadata() @@ -1435,6 +1435,7 @@ class UserAttributeDict(TypedDict): localpart: Optional[str] confirm_localpart: bool display_name: Optional[str] + picture: Optional[str] # may be omitted by older `OidcMappingProviders` emails: List[str] @@ -1520,6 +1521,7 @@ def jinja_finalize(thing: Any) -> Any: @attr.s(slots=True, frozen=True, auto_attribs=True) class JinjaOidcMappingConfig: subject_claim: str + picture_claim: str localpart_template: Optional[Template] display_name_template: Optional[Template] email_template: Optional[Template] @@ -1539,6 +1541,7 @@ def __init__(self, config: JinjaOidcMappingConfig): @staticmethod def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") + picture_claim = config.get("picture_claim", "picture") def parse_template_config(option_name: str) -> Optional[Template]: if option_name not in config: @@ -1572,6 +1575,7 @@ def parse_template_config(option_name: str) -> Optional[Template]: return JinjaOidcMappingConfig( subject_claim=subject_claim, + picture_claim=picture_claim, localpart_template=localpart_template, display_name_template=display_name_template, email_template=email_template, @@ -1611,10 +1615,13 @@ def render_template_field(template: Optional[Template]) -> Optional[str]: if email: emails.append(email) + picture = userinfo.get("picture") + return UserAttributeDict( localpart=localpart, display_name=display_name, emails=emails, + picture=picture, confirm_localpart=self._config.confirm_localpart, ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 44e29311536e..aac5ef6612f0 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -460,6 +460,12 @@ async def get_messages( if pagin_config.from_token: from_token = pagin_config.from_token + elif pagin_config.direction == "f": + from_token = ( + await self.hs.get_event_sources().get_start_token_for_pagination( + room_id + ) + ) else: from_token = ( await self.hs.get_event_sources().get_current_token_for_pagination( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 0066d63987b7..cf08737d115a 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -201,7 +201,7 @@ async def current_state_for_users( """Get the current presence state for multiple users. Returns: - dict: `user_id` -> `UserPresenceState` + A mapping of `user_id` -> `UserPresenceState` """ states = {} missing = [] @@ -478,7 +478,7 @@ async def user_syncing( return _NullContextManager() prev_state = await self.current_state_for_user(user_id) - if prev_state != PresenceState.BUSY: + if prev_state.state != PresenceState.BUSY: # We set state here but pass ignore_status_msg = True as we don't want to # cause the status message to be cleared. # Note that this causes last_active_ts to be incremented which is not diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 2e49716c23f8..57c7ca2f1bc2 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -92,7 +92,6 @@ async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None continue # Check if these receipts apply to a thread. - thread_id = None data = user_values.get("data", {}) thread_id = data.get("thread_id") # If the thread ID is invalid, consider it missing. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ca1c7a18667e..6307fa9c5d66 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -38,6 +38,7 @@ ) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ( @@ -841,6 +842,9 @@ class and RegisterDeviceReplicationServlet. refresh_token = None refresh_token_id = None + # This can only run on the main process. + assert isinstance(self.device_handler, DeviceHandler) + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8e71dda970c0..e96f9999a8d6 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -13,17 +13,19 @@ # limitations under the License. import enum import logging -from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional import attr from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamToken, UserID +from synapse.types import JsonDict, Requester, UserID +from synapse.util.async_helpers import gather_results from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -172,40 +174,6 @@ async def get_relations( return return_value - async def get_relations_for_event( - self, - event_id: str, - event: EventBase, - room_id: str, - relation_type: str, - ignored_users: FrozenSet[str] = frozenset(), - ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: - """Get a list of events which relate to an event, ordered by topological ordering. - - Args: - event_id: Fetch events that relate to this event ID. - event: The matching EventBase to event_id. - room_id: The room the event belongs to. - relation_type: The type of relation. - ignored_users: The users ignored by the requesting user. - - Returns: - List of event IDs that match relations requested. The rows are of - the form `{"event_id": "..."}`. - """ - - # Call the underlying storage method, which is cached. - related_events, next_token = await self._main_store.get_relations_for_event( - event_id, event, room_id, relation_type, direction="f" - ) - - # Filter out ignored users and convert to the expected format. - related_events = [ - event for event in related_events if event.sender not in ignored_users - ] - - return related_events, next_token - async def redact_events_related_to( self, requester: Requester, @@ -259,51 +227,107 @@ async def redact_events_related_to( e.msg, ) - async def get_annotations_for_event( - self, - event_id: str, - room_id: str, - limit: int = 5, - ignored_users: FrozenSet[str] = frozenset(), - ) -> List[JsonDict]: - """Get a list of annotations on the event, grouped by event type and + async def get_annotations_for_events( + self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() + ) -> Dict[str, List[JsonDict]]: + """Get a list of annotations to the given events, grouped by event type and aggregation key, sorted by count. - This is used e.g. to get the what and how many reactions have happend + This is used e.g. to get the what and how many reactions have happened on an event. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. ignored_users: The users ignored by the requesting user. Returns: - List of groups of annotations that match. Each row is a dict with - `type`, `key` and `count` fields. + A map of event IDs to a list of groups of annotations that match. + Each entry is a dict with `type`, `key` and `count` fields. """ # Get the base results for all users. - full_results = await self._main_store.get_aggregation_groups_for_event( - event_id, room_id, limit + full_results = await self._main_store.get_aggregation_groups_for_events( + event_ids ) + # Avoid additional logic if there are no ignored users. + if not ignored_users: + return { + event_id: results + for event_id, results in full_results.items() + if results + } + # Then subtract off the results for any ignored users. ignored_results = await self._main_store.get_aggregation_groups_for_users( - event_id, room_id, limit, ignored_users + [event_id for event_id, results in full_results.items() if results], + ignored_users, ) - filtered_results = [] - for result in full_results: - key = (result["type"], result["key"]) - if key in ignored_results: - result = result.copy() - result["count"] -= ignored_results[key] - if result["count"] <= 0: - continue - filtered_results.append(result) + filtered_results = {} + for event_id, results in full_results.items(): + # If no annotations, skip. + if not results: + continue + + # If there are not ignored results for this event, copy verbatim. + if event_id not in ignored_results: + filtered_results[event_id] = results + continue + + # Otherwise, subtract out the ignored results. + event_ignored_results = ignored_results[event_id] + for result in results: + key = (result["type"], result["key"]) + if key in event_ignored_results: + # Ensure to not modify the cache. + result = result.copy() + result["count"] -= event_ignored_results[key] + if result["count"] <= 0: + continue + filtered_results.setdefault(event_id, []).append(result) return filtered_results + async def get_references_for_events( + self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() + ) -> Dict[str, List[_RelatedEvent]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to this event ID. + ignored_users: The users ignored by the requesting user. + + Returns: + A map of event IDs to a list related events. + """ + + related_events = await self._main_store.get_references_for_events(event_ids) + + # Avoid additional logic if there are no ignored users. + if not ignored_users: + return { + event_id: results + for event_id, results in related_events.items() + if results + } + + # Filter out ignored users. + results = {} + for event_id, events in related_events.items(): + # If no references, skip. + if not events: + continue + + # Filter ignored users out. + events = [event for event in events if event.sender not in ignored_users] + # If there are no events left, skip this event. + if not events: + continue + + results[event_id] = events + + return results + async def _get_threads_for_events( self, events_by_id: Dict[str, EventBase], @@ -366,59 +390,66 @@ async def _get_threads_for_events( results = {} for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event = summary - - # Subtract off the count of any ignored users. - for ignored_user in ignored_users: - thread_count -= ignored_results.get((event_id, ignored_user), 0) - - # This is gnarly, but if the latest event is from an ignored user, - # attempt to find one that isn't from an ignored user. - if latest_thread_event.sender in ignored_users: - room_id = latest_thread_event.room_id - - # If the root event is not found, something went wrong, do - # not include a summary of the thread. - event = await self._event_handler.get_event(user, room_id, event_id) - if event is None: - continue + # If no thread, skip. + if not summary: + continue - potential_events, _ = await self.get_relations_for_event( - event_id, - event, - room_id, - RelationTypes.THREAD, - ignored_users, - ) + thread_count, latest_thread_event = summary - # If all found events are from ignored users, do not include - # a summary of the thread. - if not potential_events: - continue + # Subtract off the count of any ignored users. + for ignored_user in ignored_users: + thread_count -= ignored_results.get((event_id, ignored_user), 0) - # The *last* event returned is the one that is cared about. - event = await self._event_handler.get_event( - user, room_id, potential_events[-1].event_id - ) - # It is unexpected that the event will not exist. - if event is None: - logger.warning( - "Unable to fetch latest event in a thread with event ID: %s", - potential_events[-1].event_id, - ) - continue - latest_thread_event = event - - results[event_id] = _ThreadAggregation( - latest_event=latest_thread_event, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=events_by_id[event_id].sender == user_id - or participated[event_id], + # This is gnarly, but if the latest event is from an ignored user, + # attempt to find one that isn't from an ignored user. + if latest_thread_event.sender in ignored_users: + room_id = latest_thread_event.room_id + + # If the root event is not found, something went wrong, do + # not include a summary of the thread. + event = await self._event_handler.get_event(user, room_id, event_id) + if event is None: + continue + + # Attempt to find another event to use as the latest event. + potential_events, _ = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.THREAD, direction="f" ) + # Filter out ignored users. + potential_events = [ + event + for event in potential_events + if event.sender not in ignored_users + ] + + # If all found events are from ignored users, do not include + # a summary of the thread. + if not potential_events: + continue + + # The *last* event returned is the one that is cared about. + event = await self._event_handler.get_event( + user, room_id, potential_events[-1].event_id + ) + # It is unexpected that the event will not exist. + if event is None: + logger.warning( + "Unable to fetch latest event in a thread with event ID: %s", + potential_events[-1].event_id, + ) + continue + latest_thread_event = event + + results[event_id] = _ThreadAggregation( + latest_event=latest_thread_event, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=events_by_id[event_id].sender == user_id + or participated[event_id], + ) + return results @trace @@ -496,49 +527,56 @@ async def get_bundled_aggregations( # (as that is what makes it part of the thread). relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD - # Fetch other relations per event. - for event in events_by_id.values(): - # Fetch any annotations (ie, reactions) to bundle with this event. - annotations = await self.get_annotations_for_event( - event.event_id, event.room_id, ignored_users=ignored_users + async def _fetch_annotations() -> None: + """Fetch any annotations (ie, reactions) to bundle with this event.""" + annotations_by_event_id = await self.get_annotations_for_events( + events_by_id.keys(), ignored_users=ignored_users ) - if annotations: - results.setdefault( - event.event_id, BundledAggregations() - ).annotations = {"chunk": annotations} - - # Fetch any references to bundle with this event. - references, next_token = await self.get_relations_for_event( - event.event_id, - event, - event.room_id, - RelationTypes.REFERENCE, - ignored_users=ignored_users, + for event_id, annotations in annotations_by_event_id.items(): + if annotations: + results.setdefault(event_id, BundledAggregations()).annotations = { + "chunk": annotations + } + + async def _fetch_references() -> None: + """Fetch any references to bundle with this event.""" + references_by_event_id = await self.get_references_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, references in references_by_event_id.items(): + if references: + results.setdefault(event_id, BundledAggregations()).references = { + "chunk": [{"event_id": ev.event_id} for ev in references] + } + + async def _fetch_edits() -> None: + """ + Fetch any edits (but not for redacted events). + + Note that there is no use in limiting edits by ignored users since the + parent event should be ignored in the first place if the user is ignored. + """ + edits = await self._main_store.get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) + for event_id, edit in edits.items(): + results.setdefault(event_id, BundledAggregations()).replace = edit + + # Parallelize the calls for annotations, references, and edits since they + # are unrelated. + await make_deferred_yieldable( + gather_results( + ( + run_in_background(_fetch_annotations), + run_in_background(_fetch_references), + run_in_background(_fetch_edits), + ) ) - if references: - aggregations = results.setdefault(event.event_id, BundledAggregations()) - aggregations.references = { - "chunk": [{"event_id": ev.event_id} for ev in references] - } - - if next_token: - aggregations.references["next_batch"] = await next_token.to_string( - self._main_store - ) - - # Fetch any edits (but not for redacted events). - # - # Note that there is no use in limiting edits by ignored users since the - # parent event should be ignored in the first place if the user is ignored. - edits = await self._main_store.get_applicable_edits( - [ - event_id - for event_id, event in events_by_id.items() - if not event.internal_metadata.is_redacted() - ] ) - for event_id, edit in edits.items(): - results.setdefault(event_id, BundledAggregations()).replace = edit return results @@ -571,7 +609,7 @@ async def get_threads( room_id, requester, allow_departed_users=True ) - # Note that ignored users are not passed into get_relations_for_event + # Note that ignored users are not passed into get_threads # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). thread_roots, next_batch = await self._main_store.get_threads( diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 9602f0d0bb48..874860d461e0 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -441,7 +441,7 @@ def saml_response_to_user_attributes( client_redirect_url: where the client wants to redirect to Returns: - dict: A dict containing new user attributes. Possible keys: + A dict containing new user attributes. Possible keys: * mxid_localpart (str): Required. The localpart of the user's mxid * displayname (str): The displayname of the user * emails (list[str]): Any emails for the user @@ -483,7 +483,7 @@ def parse_config(config: dict) -> SamlConfig: Args: config: A dictionary containing configuration options for this provider Returns: - SamlConfig: A custom config object for this module + A custom config object for this module """ # Parse config options and use defaults where necessary mxid_source_attribute = config.get("mxid_source_attribute", "uid") diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 73861bbd4085..bd9d0bb34b1a 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.types import Requester if TYPE_CHECKING: @@ -29,7 +30,10 @@ class SetPasswordHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + # This can only be instantiated on the main process. + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler async def set_password( self, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 749d7e93b0f6..44e70fc4b874 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import hashlib +import io import logging from typing import ( TYPE_CHECKING, @@ -37,6 +39,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement +from synapse.handlers.device import DeviceHandler from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent @@ -137,6 +140,7 @@ class UserAttributes: localpart: Optional[str] confirm_localpart: bool = False display_name: Optional[str] = None + picture: Optional[str] = None emails: Collection[str] = attr.Factory(list) @@ -195,6 +199,10 @@ def __init__(self, hs: "HomeServer"): self._error_template = hs.config.sso.sso_error_template self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._profile_handler = hs.get_profile_handler() + self._media_repo = ( + hs.get_media_repository() if hs.config.media.can_load_media_repo else None + ) + self._http_client = hs.get_proxied_blacklisted_http_client() # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. @@ -494,6 +502,8 @@ async def complete_sso_login_request( await self._profile_handler.set_displayname( user_id_obj, requester, attributes.display_name, True ) + if attributes.picture: + await self.set_avatar(user_id, attributes.picture) await self._auth_handler.complete_sso_login( user_id, @@ -702,8 +712,110 @@ async def _register_mapped_user( await self._store.record_user_external_id( auth_provider_id, remote_user_id, registered_user_id ) + + # Set avatar, if available + if attributes.picture: + await self.set_avatar(registered_user_id, attributes.picture) + return registered_user_id + async def set_avatar(self, user_id: str, picture_https_url: str) -> bool: + """Set avatar of the user. + + This downloads the image file from the URL provided, stores that in + the media repository and then sets the avatar on the user's profile. + + It can detect if the same image is being saved again and bails early by storing + the hash of the file in the `upload_name` of the avatar image. + + Currently, it only supports server configurations which run the media repository + within the same process. + + It silently fails and logs a warning by raising an exception and catching it + internally if: + * it is unable to fetch the image itself (non 200 status code) or + * the image supplied is bigger than max allowed size or + * the image type is not one of the allowed image types. + + Args: + user_id: matrix user ID in the form @localpart:domain as a string. + + picture_https_url: HTTPS url for the picture image file. + + Returns: `True` if the user's avatar has been successfully set to the image at + `picture_https_url`. + """ + if self._media_repo is None: + logger.info( + "failed to set user avatar because out-of-process media repositories " + "are not supported yet " + ) + return False + + try: + uid = UserID.from_string(user_id) + + def is_allowed_mime_type(content_type: str) -> bool: + if ( + self._profile_handler.allowed_avatar_mimetypes + and content_type + not in self._profile_handler.allowed_avatar_mimetypes + ): + return False + return True + + # download picture, enforcing size limit & mime type check + picture = io.BytesIO() + + content_length, headers, uri, code = await self._http_client.get_file( + url=picture_https_url, + output_stream=picture, + max_size=self._profile_handler.max_avatar_size, + is_allowed_content_type=is_allowed_mime_type, + ) + + if code != 200: + raise Exception( + "GET request to download sso avatar image returned {}".format(code) + ) + + # upload name includes hash of the image file's content so that we can + # easily check if it requires an update or not, the next time user logs in + upload_name = "sso_avatar_" + hashlib.sha256(picture.read()).hexdigest() + + # bail if user already has the same avatar + profile = await self._profile_handler.get_profile(user_id) + if profile["avatar_url"] is not None: + server_name = profile["avatar_url"].split("/")[-2] + media_id = profile["avatar_url"].split("/")[-1] + if server_name == self._server_name: + media = await self._media_repo.store.get_local_media(media_id) + if media is not None and upload_name == media["upload_name"]: + logger.info("skipping saving the user avatar") + return True + + # store it in media repository + avatar_mxc_url = await self._media_repo.create_content( + media_type=headers[b"Content-Type"][0].decode("utf-8"), + upload_name=upload_name, + content=picture, + content_length=content_length, + auth_user=uid, + ) + + # save it as user avatar + await self._profile_handler.set_avatar_url( + uid, + create_requester(uid), + str(avatar_mxc_url), + ) + + logger.info("successfully saved the user avatar") + return True + except Exception: + logger.warning("failed to save the user avatar") + return False + async def complete_sso_ui_auth_request( self, auth_provider_id: str, @@ -1035,6 +1147,8 @@ async def revoke_sessions_for_provider_session_id( ) -> None: """Revoke any devices and in-flight logins tied to a provider session. + Can only be called from the main process. + Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". @@ -1042,6 +1156,12 @@ async def revoke_sessions_for_provider_session_id( expected_user_id: The user we're expecting to logout. If set, it will ignore sessions belonging to other users and log an error. """ + + # It is expected that this is the main process. + assert isinstance( + self._device_handler, DeviceHandler + ), "revoking SSO sessions can only be called on the main process" + # Invalidate any running user-mapping sessions to_delete = [] for session_id, session in self._username_mapping_sessions.items(): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 77d868fc2369..db4395311b80 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1467,14 +1467,14 @@ async def generate_sync_result( logger.debug("Fetching OTK data") device_id = sync_config.device_id - one_time_key_counts: JsonDict = {} + one_time_keys_count: JsonDict = {} unused_fallback_key_types: List[str] = [] if device_id: # TODO: We should have a way to let clients differentiate between the states of: # * no change in OTK count since the provided since token # * the server has zero OTKs left for this device # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298 - one_time_key_counts = await self.store.count_e2e_one_time_keys( + one_time_keys_count = await self.store.count_e2e_one_time_keys( user_id, device_id ) unused_fallback_key_types = ( @@ -1504,7 +1504,7 @@ async def generate_sync_result( archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, - device_one_time_keys_count=one_time_key_counts, + device_one_time_keys_count=one_time_keys_count, device_unused_fallback_key_types=unused_fallback_key_types, next_batch=sync_result_builder.now_token, ) diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 6a9f6635d2c0..8729630581b5 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -45,8 +45,7 @@ def __init__( Args: hs: homeserver - handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): - function to be called to handle the request. + handler: function to be called to handle the request. """ super().__init__() self._handler = handler diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 2f0177f1e203..0359231e7dd3 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -155,11 +155,10 @@ def request( a file for a file upload). Or None if the request is to have no body. Returns: - Deferred[twisted.web.iweb.IResponse]: - fires when the header of the response has been received (regardless of the - response status code). Fails if there is any problem which prevents that - response from being received (including problems that prevent the request - from being sent). + A deferred which fires when the header of the response has been received + (regardless of the response status code). Fails if there is any problem + which prevents that response from being received (including problems that + prevent the request from being sent). """ # We use urlparse as that will set `port` to None if there is no # explicit port. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3c35b1d2c7af..b92f1d3d1af5 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -951,8 +951,7 @@ async def post_json( args: query params Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The - result will be the decoded JSON body. + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: HttpResponseException: If we get an HTTP response code >= 300 diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 1f8227896f65..18899bc6d18d 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -34,7 +34,7 @@ ) from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS +from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse from synapse.http import redact_uri from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials @@ -134,7 +134,7 @@ def request( uri: bytes, headers: Optional[Headers] = None, bodyProducer: Optional[IBodyProducer] = None, - ) -> defer.Deferred: + ) -> "defer.Deferred[IResponse]": """ Issue a request to the server indicated by the given uri. @@ -157,17 +157,17 @@ def request( a file upload). Or, None if the request is to have no body. Returns: - Deferred[IResponse]: completes when the header of the response has - been received (regardless of the response status code). + A deferred which completes when the header of the response has + been received (regardless of the response status code). - Can fail with: - SchemeNotSupported: if the uri is not http or https + Can fail with: + SchemeNotSupported: if the uri is not http or https - twisted.internet.error.TimeoutError if the server we are connecting - to (proxy or destination) does not accept a connection before - connectTimeout. + twisted.internet.error.TimeoutError if the server we are connecting + to (proxy or destination) does not accept a connection before + connectTimeout. - ... other things too. + ... other things too. """ uri = uri.strip() if not _VALID_URI.match(uri): diff --git a/synapse/http/server.py b/synapse/http/server.py index 65eec0481e4d..2563858f3cdf 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -267,7 +267,7 @@ def register_paths( request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. This should return either tuple of (code, response), or None. - servlet_classname (str): The name of the handler to be used in prometheus + servlet_classname: The name of the handler to be used in prometheus and opentracing logs. """ diff --git a/synapse/http/site.py b/synapse/http/site.py index ac6680f681ec..3a4d91f60b8a 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -401,7 +401,7 @@ def _started_processing(self, servlet_name: str) -> None: be sure to call finished_processing. Args: - servlet_name (str): the name of the servlet which will be + servlet_name: the name of the servlet which will be processing this request. This is used in the metrics. It is possible to update this afterwards by updating diff --git a/synapse/logging/context.py b/synapse/logging/context.py index e2f4a1c80ba5..28213a1ac140 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -117,8 +117,7 @@ def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None: """Create a new ContextResourceUsage Args: - copy_from (ContextResourceUsage|None): if not None, an object to - copy stats from + copy_from: if not None, an object to copy stats from """ if copy_from is None: self.reset() @@ -162,7 +161,7 @@ def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage": """Add another ContextResourceUsage's stats to this one's. Args: - other (ContextResourceUsage): the other resource usage object + other: the other resource usage object """ self.ru_utime += other.ru_utime self.ru_stime += other.ru_stime @@ -343,7 +342,7 @@ def current_context(cls) -> LoggingContextOrSentinel: called directly. Returns: - LoggingContext: the current logging context + The current logging context """ warnings.warn( "synapse.logging.context.LoggingContext.current_context() is deprecated " @@ -363,7 +362,8 @@ def set_current_context( called directly. Args: - context(LoggingContext): The context to activate. + context: The context to activate. + Returns: The context that was previously active """ @@ -475,8 +475,7 @@ def get_resource_usage(self) -> ContextResourceUsage: """Get resources used by this logcontext so far. Returns: - ContextResourceUsage: a *copy* of the object tracking resource - usage so far + A *copy* of the object tracking resource usage so far """ # we always return a copy, for consistency res = self._resource_usage.copy() @@ -665,7 +664,8 @@ def current_context() -> LoggingContextOrSentinel: def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel: """Set the current logging context in thread local storage Args: - context(LoggingContext): The context to activate. + context: The context to activate. + Returns: The context that was previously active """ @@ -702,7 +702,7 @@ def nested_logging_context(suffix: str) -> LoggingContext: suffix: suffix to add to the parent context's 'name'. Returns: - LoggingContext: new logging context. + A new logging context. """ curr_context = current_context() if not curr_context: @@ -900,20 +900,19 @@ def defer_to_thread( on it. Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked, and whose threadpool we should use for the - function. + reactor: The reactor in whose main thread the Deferred will be invoked, + and whose threadpool we should use for the function. Normally this will be hs.get_reactor(). - f (callable): The function to call. + f: The function to call. args: positional arguments to pass to f. kwargs: keyword arguments to pass to f. Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an + A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) @@ -941,20 +940,20 @@ def defer_to_threadpool( on it. Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked. Normally this will be hs.get_reactor(). + reactor: The reactor in whose main thread the Deferred will be invoked. + Normally this will be hs.get_reactor(). - threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for - running `f`. Normally this will be hs.get_reactor().getThreadPool(). + threadpool: The threadpool to use for running `f`. Normally this will be + hs.get_reactor().getThreadPool(). - f (callable): The function to call. + f: The function to call. args: positional arguments to pass to f. kwargs: keyword arguments to pass to f. Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an + A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ curr_context = current_context() diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 8ce5a2a33818..b69060854f07 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -721,7 +721,7 @@ def inject_header_dict( destination: address of entity receiving the span context. Must be given unless check_destination is False. The context will only be injected if the destination matches the opentracing whitelist - check_destination (bool): If false, destination will be ignored and the context + check_destination: If false, destination will be ignored and the context will always be injected. Note: @@ -780,7 +780,7 @@ def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str destination: the name of the remote server. Returns: - dict: the active span's context if opentracing is enabled, otherwise empty. + the active span's context if opentracing is enabled, otherwise empty. """ if destination and not whitelisted_homeserver(destination): diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index c3d3daf8774c..b01372565d14 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -47,11 +47,7 @@ # This module is imported for its side effects; flake8 needn't warn that it's unused. import synapse.metrics._reactor_metrics # noqa: F401 from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager -from synapse.metrics._legacy_exposition import ( - MetricsResource, - generate_latest, - start_http_server, -) +from synapse.metrics._twisted_exposition import MetricsResource, generate_latest from synapse.metrics._types import Collector from synapse.util import SYNAPSE_VERSION @@ -474,7 +470,6 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None: "Collector", "MetricsResource", "generate_latest", - "start_http_server", "LaterGauge", "InFlightGauge", "GaugeBucketCollector", diff --git a/synapse/metrics/_legacy_exposition.py b/synapse/metrics/_legacy_exposition.py deleted file mode 100644 index 1459f9d224b3..000000000000 --- a/synapse/metrics/_legacy_exposition.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2015-2019 Prometheus Python Client Developers -# Copyright 2019 Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This code is based off `prometheus_client/exposition.py` from version 0.7.1. - -Due to the renaming of metrics in prometheus_client 0.4.0, this customised -vendoring of the code will emit both the old versions that Synapse dashboards -expect, and the newer "best practice" version of the up-to-date official client. -""" -import logging -import math -import threading -from http.server import BaseHTTPRequestHandler, HTTPServer -from socketserver import ThreadingMixIn -from typing import Any, Dict, List, Type, Union -from urllib.parse import parse_qs, urlparse - -from prometheus_client import REGISTRY, CollectorRegistry -from prometheus_client.core import Sample - -from twisted.web.resource import Resource -from twisted.web.server import Request - -logger = logging.getLogger(__name__) -CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" - - -def floatToGoString(d: Union[int, float]) -> str: - d = float(d) - if d == math.inf: - return "+Inf" - elif d == -math.inf: - return "-Inf" - elif math.isnan(d): - return "NaN" - else: - s = repr(d) - dot = s.find(".") - # Go switches to exponents sooner than Python. - # We only need to care about positive values for le/quantile. - if d > 0 and dot > 6: - mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.") - return f"{mantissa}e+0{dot - 1}" - return s - - -def sample_line(line: Sample, name: str) -> str: - if line.labels: - labelstr = "{{{0}}}".format( - ",".join( - [ - '{}="{}"'.format( - k, - v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""), - ) - for k, v in sorted(line.labels.items()) - ] - ) - ) - else: - labelstr = "" - timestamp = "" - if line.timestamp is not None: - # Convert to milliseconds. - timestamp = f" {int(float(line.timestamp) * 1000):d}" - return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) - - -# Mapping from new metric names to legacy metric names. -# We translate these back to their old names when exposing them through our -# legacy vendored exporter. -# Only this legacy exposition module applies these name changes. -LEGACY_METRIC_NAMES = { - "synapse_util_caches_cache_hits": "synapse_util_caches_cache:hits", - "synapse_util_caches_cache_size": "synapse_util_caches_cache:size", - "synapse_util_caches_cache_evicted_size": "synapse_util_caches_cache:evicted_size", - "synapse_util_caches_cache": "synapse_util_caches_cache:total", - "synapse_util_caches_response_cache_size": "synapse_util_caches_response_cache:size", - "synapse_util_caches_response_cache_hits": "synapse_util_caches_response_cache:hits", - "synapse_util_caches_response_cache_evicted_size": "synapse_util_caches_response_cache:evicted_size", - "synapse_util_caches_response_cache": "synapse_util_caches_response_cache:total", - "synapse_federation_client_sent_pdu_destinations": "synapse_federation_client_sent_pdu_destinations:total", - "synapse_federation_client_sent_pdu_destinations_count": "synapse_federation_client_sent_pdu_destinations:count", - "synapse_admin_mau_current": "synapse_admin_mau:current", - "synapse_admin_mau_max": "synapse_admin_mau:max", - "synapse_admin_mau_registered_reserved_users": "synapse_admin_mau:registered_reserved_users", -} - - -def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes: - """ - Generate metrics in legacy format. Modern metrics are generated directly - by prometheus-client. - """ - - output = [] - - for metric in registry.collect(): - if not metric.samples: - # No samples, don't bother. - continue - - # Translate to legacy metric name if it has one. - mname = LEGACY_METRIC_NAMES.get(metric.name, metric.name) - mnewname = metric.name - mtype = metric.type - - # OpenMetrics -> Prometheus - if mtype == "counter": - mnewname = mnewname + "_total" - elif mtype == "info": - mtype = "gauge" - mnewname = mnewname + "_info" - elif mtype == "stateset": - mtype = "gauge" - elif mtype == "gaugehistogram": - mtype = "histogram" - elif mtype == "unknown": - mtype = "untyped" - - # Output in the old format for compatibility. - if emit_help: - output.append( - "# HELP {} {}\n".format( - mname, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mname} {mtype}\n") - - om_samples: Dict[str, List[str]] = {} - for s in metric.samples: - for suffix in ["_created", "_gsum", "_gcount"]: - if s.name == mname + suffix: - # OpenMetrics specific sample, put in a gauge at the end. - # (these come from gaugehistograms which don't get renamed, - # so no need to faff with mnewname) - om_samples.setdefault(suffix, []).append(sample_line(s, s.name)) - break - else: - newname = s.name.replace(mnewname, mname) - if ":" in newname and newname.endswith("_total"): - newname = newname[: -len("_total")] - output.append(sample_line(s, newname)) - - for suffix, lines in sorted(om_samples.items()): - if emit_help: - output.append( - "# HELP {}{} {}\n".format( - mname, - suffix, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mname}{suffix} gauge\n") - output.extend(lines) - - # Get rid of the weird colon things while we're at it - if mtype == "counter": - mnewname = mnewname.replace(":total", "") - mnewname = mnewname.replace(":", "_") - - if mname == mnewname: - continue - - # Also output in the new format, if it's different. - if emit_help: - output.append( - "# HELP {} {}\n".format( - mnewname, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mnewname} {mtype}\n") - - for s in metric.samples: - # Get rid of the OpenMetrics specific samples (we should already have - # dealt with them above anyway.) - for suffix in ["_created", "_gsum", "_gcount"]: - if s.name == mname + suffix: - break - else: - sample_name = LEGACY_METRIC_NAMES.get(s.name, s.name) - output.append( - sample_line(s, sample_name.replace(":total", "").replace(":", "_")) - ) - - return "".join(output).encode("utf-8") - - -class MetricsHandler(BaseHTTPRequestHandler): - """HTTP handler that gives metrics from ``REGISTRY``.""" - - registry = REGISTRY - - def do_GET(self) -> None: - registry = self.registry - params = parse_qs(urlparse(self.path).query) - - if "help" in params: - emit_help = True - else: - emit_help = False - - try: - output = generate_latest(registry, emit_help=emit_help) - except Exception: - self.send_error(500, "error generating metric output") - raise - try: - self.send_response(200) - self.send_header("Content-Type", CONTENT_TYPE_LATEST) - self.send_header("Content-Length", str(len(output))) - self.end_headers() - self.wfile.write(output) - except BrokenPipeError as e: - logger.warning( - "BrokenPipeError when serving metrics (%s). Did Prometheus restart?", e - ) - - def log_message(self, format: str, *args: Any) -> None: - """Log nothing.""" - - @classmethod - def factory(cls, registry: CollectorRegistry) -> Type: - """Returns a dynamic MetricsHandler class tied - to the passed registry. - """ - # This implementation relies on MetricsHandler.registry - # (defined above and defaulted to REGISTRY). - - # As we have unicode_literals, we need to create a str() - # object for type(). - cls_name = str(cls.__name__) - MyMetricsHandler = type(cls_name, (cls, object), {"registry": registry}) - return MyMetricsHandler - - -class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer): - """Thread per request HTTP server.""" - - # Make worker threads "fire and forget". Beginning with Python 3.7 this - # prevents a memory leak because ``ThreadingMixIn`` starts to gather all - # non-daemon threads in a list in order to join on them at server close. - # Enabling daemon threads virtually makes ``_ThreadingSimpleServer`` the - # same as Python 3.7's ``ThreadingHTTPServer``. - daemon_threads = True - - -def start_http_server( - port: int, addr: str = "", registry: CollectorRegistry = REGISTRY -) -> None: - """Starts an HTTP server for prometheus metrics as a daemon thread""" - CustomMetricsHandler = MetricsHandler.factory(registry) - httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler) - t = threading.Thread(target=httpd.serve_forever) - t.daemon = True - t.start() - - -class MetricsResource(Resource): - """ - Twisted ``Resource`` that serves prometheus metrics. - """ - - isLeaf = True - - def __init__(self, registry: CollectorRegistry = REGISTRY): - self.registry = registry - - def render_GET(self, request: Request) -> bytes: - request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) - response = generate_latest(self.registry) - request.setHeader(b"Content-Length", str(len(response))) - return response diff --git a/synapse/metrics/_twisted_exposition.py b/synapse/metrics/_twisted_exposition.py new file mode 100644 index 000000000000..0abcd1495383 --- /dev/null +++ b/synapse/metrics/_twisted_exposition.py @@ -0,0 +1,38 @@ +# Copyright 2015-2019 Prometheus Python Client Developers +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from prometheus_client import REGISTRY, CollectorRegistry, generate_latest + +from twisted.web.resource import Resource +from twisted.web.server import Request + +CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" + + +class MetricsResource(Resource): + """ + Twisted ``Resource`` that serves prometheus metrics. + """ + + isLeaf = True + + def __init__(self, registry: CollectorRegistry = REGISTRY): + self.registry = registry + + def render_GET(self, request: Request) -> bytes: + request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) + response = generate_latest(self.registry) + request.setHeader(b"Content-Length", str(len(response))) + return response diff --git a/synapse/metrics/common_usage_metrics.py b/synapse/metrics/common_usage_metrics.py index 0a22ea3d923c..6e05b043d3cd 100644 --- a/synapse/metrics/common_usage_metrics.py +++ b/synapse/metrics/common_usage_metrics.py @@ -54,7 +54,9 @@ async def get_metrics(self) -> CommonUsageMetrics: async def setup(self) -> None: """Keep the gauges for common usage metrics up to date.""" - await self._update_gauges() + run_as_background_process( + desc="common_usage_metrics_update_gauges", func=self._update_gauges + ) self._clock.looping_call( run_as_background_process, 5 * 60 * 1000, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 30e689d00d2c..96a661177abd 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,6 +86,7 @@ ON_LOGGED_OUT_CALLBACK, AuthHandler, ) +from synapse.handlers.device import DeviceHandler from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient from synapse.http.server import ( @@ -207,6 +208,7 @@ def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None: self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self._push_rules_handler = hs.get_push_rules_handler() + self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory try: @@ -784,10 +786,12 @@ def invalidate_access_token( ) -> Generator["defer.Deferred[Any]", Any, None]: """Invalidate an access token for a user + Can only be called from the main process. + Added in Synapse v0.25.0. Args: - access_token(str): access token + access_token: access token Returns: twisted.internet.defer.Deferred - resolves once the access token @@ -796,6 +800,10 @@ def invalidate_access_token( Raises: synapse.api.errors.AuthError: the access token is invalid """ + assert isinstance( + self._device_handler, DeviceHandler + ), "invalidate_access_token can only be called on the main process" + # see if the access token corresponds to a device user_info = yield defer.ensureDeferred( self._auth.get_user_by_access_token(access_token) @@ -805,7 +813,7 @@ def invalidate_access_token( if device_id: # delete the device, which will also delete its access tokens yield defer.ensureDeferred( - self._hs.get_device_handler().delete_devices(user_id, [device_id]) + self._device_handler.delete_devices(user_id, [device_id]) ) else: # no associated device. Just delete the access token. @@ -832,7 +840,7 @@ def run_db_interaction( **kwargs: named args to be passed to func Returns: - Deferred[object]: result of func + Result of func """ # type-ignore: See https://github.com/python/mypy/issues/8862 return defer.ensureDeferred( @@ -924,8 +932,7 @@ def get_state_events_in_room( to represent 'any') of the room state to acquire. Returns: - twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: - The filtered state events in the room. + The filtered state events in the room. """ state_ids = yield defer.ensureDeferred( self._storage_controllers.state.get_current_state_ids( diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index aa87f0253494..f98affb07bd9 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -35,6 +35,7 @@ Membership, RelationTypes, ) +from synapse.api.room_versions import PushRuleRoomFlag, RoomVersion from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext @@ -340,13 +341,19 @@ async def _action_for_event_by_user( for user_id, level in notification_levels.items(): notification_levels[user_id] = int(level) + room_version_features = event.room_version.msc3931_push_features + if not room_version_features: + room_version_features = [] + evaluator = PushRuleEvaluator( - _flatten_dict(event), + _flatten_dict(event, room_version=event.room_version), non_bot_room_member_count, sender_power_level, notification_levels, related_events, self._related_event_match_enabled, + room_version_features, + self.hs.config.experimental.msc1767_enabled, # MSC3931 flag ) users = rules_by_user.keys() @@ -434,6 +441,7 @@ async def _action_for_event_by_user( def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], + room_version: Optional[RoomVersion] = None, prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: @@ -447,6 +455,31 @@ def _flatten_dict( elif isinstance(value, bool): result[".".join(prefix + [key])] = str(value).lower() elif isinstance(value, Mapping): + # do not set `room_version` due to recursion considerations below _flatten_dict(value, prefix=(prefix + [key]), result=result) + # `room_version` should only ever be set when looking at the top level of an event + if ( + room_version is not None + and PushRuleRoomFlag.EXTENSIBLE_EVENTS in room_version.msc3931_push_features + and isinstance(d, EventBase) + ): + # Room supports extensible events: replace `content.body` with the plain text + # representation from `m.markup`, as per MSC1767. + markup = d.get("content").get("m.markup") + if room_version.identifier.startswith("org.matrix.msc1767."): + markup = d.get("content").get("org.matrix.msc1767.markup") + if markup is not None and isinstance(markup, list): + text = "" + for rep in markup: + if not isinstance(rep, dict): + # invalid markup - skip all processing + break + if rep.get("mimetype", "text/plain") == "text/plain": + rep_text = rep.get("body") + if rep_text is not None and isinstance(rep_text, str): + text = rep_text.lower() + break + result["content.body"] = text + return result diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 5e661f8c73c1..3f4d3fc51ae3 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -153,7 +153,7 @@ async def _serialize_payload(**kwargs) -> JsonDict: argument list. Returns: - dict: If POST/PUT request then dictionary must be JSON serialisable, + If POST/PUT request then dictionary must be JSON serialisable, otherwise must be appropriate for adding as query args. """ return {} diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 3d63645726b9..7c4941c3d3f5 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,11 +13,12 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer +from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -62,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.device_list_updater = hs.get_device_handler().device_list_updater + from synapse.handlers.device import DeviceHandler + + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_list_updater = handler.device_list_updater + self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -72,11 +78,77 @@ async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override async def _handle_request( # type: ignore[override] self, request: Request, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, Optional[JsonDict]]: user_devices = await self.device_list_updater.user_device_resync(user_id) return 200, user_devices +class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): + """Ask master to upload keys for the user and send them out over federation to + update other servers. + + For now, only the master is permitted to handle key upload requests; + any worker can handle key query requests (since they're read-only). + + Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on + the main process to accomplish this. + + Defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload + Request format(borrowed and expanded from KeyUploadServlet): + + POST /_synapse/replication/upload_keys_for_user + + { + "user_id": "", + "device_id": "", + "keys": { + ....this part can be found in KeyUploadServlet in rest/client/keys.py.... + } + } + + Response is equivalent to ` /_matrix/client/v3/keys/upload` found in KeyUploadServlet + + """ + + NAME = "upload_keys_for_user" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + user_id: str, device_id: str, keys: JsonDict + ) -> JsonDict: + + return { + "user_id": user_id, + "device_id": device_id, + "keys": keys, + } + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + content = parse_json_object_from_request(request) + + user_id = content["user_id"] + device_id = content["device_id"] + keys = content["keys"] + + results = await self.e2e_keys_handler.upload_keys_for_user( + user_id, device_id, keys + ) + + return 200, results + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationUserDevicesResyncRestServlet(hs).register(http_server) + ReplicationUploadKeysForUserRestServlet(hs).register(http_server) diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py deleted file mode 100644 index f43a360a807c..000000000000 --- a/synapse/replication/slave/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py deleted file mode 100644 index f43a360a807c..000000000000 --- a/synapse/replication/slave/storage/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py deleted file mode 100644 index 8f3f953ed474..000000000000 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Optional, Tuple - -from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id - - -class SlavedIdTracker(AbstractStreamIdTracker): - """Tracks the "current" stream ID of a stream with a single writer. - - See `AbstractStreamIdTracker` for more details. - - Note that this class does not work correctly when there are multiple - writers. - """ - - def __init__( - self, - db_conn: LoggingDatabaseConnection, - table: str, - column: str, - extra_tables: Optional[List[Tuple[str, str]]] = None, - step: int = 1, - ): - self.step = step - self._current = _load_current_id(db_conn, table, column, step) - if extra_tables: - for table, column in extra_tables: - self.advance(None, _load_current_id(db_conn, table, column)) - - def advance(self, instance_name: Optional[str], new_id: int) -> None: - self._current = (max if self.step > 0 else min)(self._current, new_id) - - def get_current_token(self) -> int: - return self._current - - def get_current_token_for_writer(self, instance_name: str) -> int: - return self.get_current_token() diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 7763ffb2d0c7..56a5c21910d9 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -245,7 +245,7 @@ def lineReceived(self, line: bytes) -> None: self._parse_and_dispatch_line(line) def _parse_and_dispatch_line(self, line: bytes) -> None: - if line.strip() == "": + if line.strip() == b"": # Ignore blank lines return diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c62ea22116e0..fb73886df061 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: """ Register all the admin servlets. """ + # Admin servlets aren't registered on workers. + if hs.config.worker.worker_app is not None: + return + register_servlets_for_client_rest_resource(hs, http_server) BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) @@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) - DeviceRestServlet(hs).register(http_server) - DevicesRestServlet(hs).register(http_server) - DeleteDevicesRestServlet(hs).register(http_server) UserMediaStatisticsRestServlet(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) @@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserByExternalId(hs).register(http_server) UserByThreePid(hs).register(http_server) - # Some servlets only get registered for the main process. - if hs.config.worker.worker_app is None: - SendServerNoticeServlet(hs).register(http_server) - BackgroundUpdateEnabledRestServlet(hs).register(http_server) - BackgroundUpdateRestServlet(hs).register(http_server) - BackgroundUpdateStartJobRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeleteDevicesRestServlet(hs).register(http_server) + SendServerNoticeServlet(hs).register(http_server) + BackgroundUpdateEnabledRestServlet(hs).register(http_server) + BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( @@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource( """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) - DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) - ResetPasswordRestServlet(hs).register(http_server) + # The following resources can only be run on the main process. + if hs.config.worker.worker_app is None: + DeactivateAccountRestServlet(hs).register(http_server) + ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d9348801026b..3b2f2d9abbd6 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 85942a2f83f8..61eb0cf0358e 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -905,8 +905,9 @@ class PushersRestServlet(RestServlet): @user:server/pushers Returns: - pushers: Dictionary containing pushers information. - total: Number of pushers in dictionary `pushers`. + A dictionary with keys: + pushers: Dictionary containing pushers information. + total: Number of pushers in dictionary `pushers`. """ PATTERNS = admin_patterns("/users/(?P[^/]*)/pushers$") diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8f3cbd4ea2e7..69b803f9f8b2 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -20,6 +20,7 @@ from synapse.api import errors from synapse.api.errors import NotFoundError +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -80,7 +81,9 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() class PostBody(RequestBodyModel): @@ -125,7 +128,9 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() self._msc3852_enabled = hs.config.experimental.msc3852_enabled @@ -256,7 +261,9 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -313,7 +320,9 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler class PostBody(RequestBodyModel): device_id: StrictStr diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index f653d2a3e174..ee038c71921a 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -27,6 +27,7 @@ ) from synapse.http.site import SynapseRequest from synapse.logging.opentracing import log_kv, set_tag +from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.types import JsonDict, StreamToken from synapse.util.cancellation import cancellable @@ -43,24 +44,48 @@ class KeyUploadServlet(RestServlet): Content-Type: application/json { - "device_keys": { - "user_id": "", - "device_id": "", - "valid_until_ts": , - "algorithms": [ - "m.olm.curve25519-aes-sha2", - ] - "keys": { - ":": "", + "device_keys": { + "user_id": "", + "device_id": "", + "valid_until_ts": , + "algorithms": [ + "m.olm.curve25519-aes-sha2", + ] + "keys": { + ":": "", + }, + "signatures:" { + "" { + ":": "" + } + } + }, + "fallback_keys": { + ":": "", + "signed_:": { + "fallback": true, + "key": "", + "signatures": { + "": { + ":": "" + } + } + } + } + "one_time_keys": { + ":": "" }, - "signatures:" { - "" { - ":": "" - } } }, - "one_time_keys": { - ":": "" - }, } + + response, e.g.: + + { + "one_time_key_counts": { + "curve25519": 10, + "signed_curve25519": 20 + } + } + """ PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") @@ -71,6 +96,13 @@ def __init__(self, hs: "HomeServer"): self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = hs.get_device_handler() + if hs.config.worker.worker_app is None: + # if main process + self.key_uploader = self.e2e_keys_handler.upload_keys_for_user + else: + # then a worker + self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs) + async def on_POST( self, request: SynapseRequest, device_id: Optional[str] ) -> Tuple[int, JsonDict]: @@ -109,8 +141,8 @@ async def on_POST( 400, "To upload keys, you must pass device_id when authenticating" ) - result = await self.e2e_keys_handler.upload_keys_for_user( - user_id, device_id, body + result = await self.key_uploader( + user_id=user_id, device_id=device_id, keys=body ) return 200, result diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 05706b598c89..8adced41e5e9 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -350,7 +350,7 @@ async def _complete_login( auth_provider_session_id: The session ID got during login from the SSO IdP. Returns: - result: Dictionary of account information after successful login. + Dictionary of account information after successful login. """ # Before we actually log them in we check if they've already logged in diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 23dfa4518fc5..6d34625ad5d6 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest @@ -34,7 +35,9 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) @@ -59,7 +62,9 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index af642118d259..e249e16e98da 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1284,17 +1284,14 @@ class TimestampLookupRestServlet(RestServlet): `dir` can be `f` or `b` to indicate forwards and backwards in time from the given timestamp. - GET /_matrix/client/unstable/org.matrix.msc3030/rooms//timestamp_to_event?ts=&dir= + GET /_matrix/client/v1/rooms//timestamp_to_event?ts=&dir= { "event_id": ... } """ PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc3030" - "/rooms/(?P[^/]*)/timestamp_to_event$" - ), + re.compile("^/_matrix/client/v1/rooms/(?P[^/]*)/timestamp_to_event$"), ) def __init__(self, hs: "HomeServer"): @@ -1421,8 +1418,7 @@ def register_servlets( RoomAliasListServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server) - if hs.config.experimental.msc3030_enabled: - TimestampLookupRestServlet(hs).register(http_server) + TimestampLookupRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. if not is_worker: diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 180a11ef88ae..3c0a90010b87 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -101,8 +101,6 @@ def on_GET(self, request: Request) -> Tuple[int, JsonDict]: "org.matrix.msc3827.stable": True, # Adds support for importing historical messages as per MSC2716 "org.matrix.msc2716": self.config.experimental.msc2716_enabled, - # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030 - "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above # Support for thread read receipts & notification counts. diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 7331a1791f51..c70e1837afd0 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -344,8 +344,8 @@ async def _get_remote_media_impl( download from remote server. Args: - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). Returns: diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 9b93b9b4f6eb..a48a4de92ae2 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -138,7 +138,7 @@ def scale(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales the image to the given dimensions. Returns: - BytesIO: the bytes of the encoded image ready to be written to disk + The bytes of the encoded image ready to be written to disk """ with self._resize(width, height) as scaled: return self._encode_image(scaled, output_type) @@ -155,7 +155,7 @@ def crop(self, width: int, height: int, output_type: str) -> BytesIO: max_height: The largest possible height. Returns: - BytesIO: the bytes of the encoded image ready to be written to disk + The bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: scaled_width = width diff --git a/synapse/server.py b/synapse/server.py index c788847893e3..981d3c2b43b1 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -511,7 +511,7 @@ def get_macaroon_generator(self) -> MacaroonGenerator: ) @cache_in_self - def get_device_handler(self): + def get_device_handler(self) -> DeviceWorkerHandler: if self.config.worker.worker_app: return DeviceWorkerHandler(self) else: diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 698ca742ed66..94025ba41f7d 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -113,9 +113,8 @@ def copy_with_str_subst(x: Any, substitutions: Any) -> Any: """Deep-copy a structure, carrying out string substitutions on any strings Args: - x (object): structure to be copied - substitutions (object): substitutions to be made - passed into the - string '%' operator + x: structure to be copied + substitutions: substitutions to be made - passed into the string '%' operator Returns: copy of x diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 3134cd2d3d6c..a31a2c99a7b8 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -170,11 +170,13 @@ async def _is_room_currently_blocked(self, room_id: str) -> Tuple[bool, List[str room_id: The room id of the server notices room Returns: - bool: Is the room currently blocked - list: The list of pinned event IDs that are unrelated to limit blocking - This list can be used as a convenience in the case where the block - is to be lifted and the remaining pinned event references need to be - preserved + Tuple of: + Is the room currently blocked + + The list of pinned event IDs that are unrelated to limit blocking + This list can be used as a convenience in the case where the block + is to be lifted and the remaining pinned event references need to be + preserved """ currently_blocked = False pinned_state_event = None diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 6f3dd0463e66..833ffec3de31 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -190,6 +190,7 @@ async def compute_state_after_events( room_id: str, event_ids: Collection[str], state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """Fetch the state after each of the given event IDs. Resolve them and return. @@ -200,13 +201,18 @@ async def compute_state_after_events( Args: room_id: the room_id containing the given events. event_ids: the events whose state should be fetched and resolved. + await_full_state: if `True`, will block if we do not yet have complete state + at the given `event_id`s, regardless of whether `state_filter` is + satisfied by partial state. Returns: the state dict (a mapping from (event_type, state_key) -> event_id) which holds the resolution of the states after the given event IDs. """ logger.debug("calling resolve_state_groups from compute_state_after_events") - ret = await self.resolve_state_groups_for_events(room_id, event_ids) + ret = await self.resolve_state_groups_for_events( + room_id, event_ids, await_full_state + ) return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_user_ids_in_room( diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 48976dc5705e..33ffef521b87 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -204,9 +204,8 @@ async def add_to_queue( process to to so, calling the per_item_callback for each item. Args: - room_id (str): - task (_EventPersistQueueTask): A _PersistEventsTask or - _UpdateCurrentStateTask to process. + room_id: + task: A _PersistEventsTask or _UpdateCurrentStateTask to process. Returns: the result returned by the `_per_item_callback` passed to diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 4717c9728a0b..55bcb90001e9 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -569,15 +569,15 @@ async def _check_safe_to_upsert(self) -> None: retcols=["update_name"], desc="check_background_updates", ) - updates = [x["update_name"] for x in updates] + background_update_names = [x["update_name"] for x in updates] for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): - if update_name not in updates: + if update_name not in background_update_names: logger.debug("Now safe to upsert in %s", table) self._unsafe_to_upsert_tables.discard(table) # If there's any updates still running, reschedule to run. - if updates: + if background_update_names: self._clock.call_later( 15.0, run_as_background_process, @@ -1129,7 +1129,6 @@ async def simple_upsert( values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, desc: str = "simple_upsert", - lock: bool = True, ) -> bool: """Insert a row with values + insertion_values; on conflict, update with values. @@ -1154,21 +1153,12 @@ async def simple_upsert( requiring that a unique index exist on the column names used to detect a conflict (i.e. `keyvalues.keys()`). - If there is no such index, we can "emulate" an upsert with a SELECT followed - by either an INSERT or an UPDATE. This is unsafe: we cannot make the same - atomicity guarantees that a native upsert can and are very vulnerable to races - and crashes. Therefore if we wish to upsert without an appropriate unique index, - we must either: - - 1. Acquire a table-level lock before the emulated upsert (`lock=True`), or - 2. VERY CAREFULLY ensure that we are the only thread and worker which will be - writing to this table, in which case we can proceed without a lock - (`lock=False`). - - Generally speaking, you should use `lock=True`. If the table in question has a - unique index[*], this class will use a native upsert (which is atomic and so can - ignore the `lock` argument). Otherwise this class will use an emulated upsert, - in which case we want the safer option unless we been VERY CAREFUL. + If there is no such index yet[*], we can "emulate" an upsert with a SELECT + followed by either an INSERT or an UPDATE. This is unsafe unless *all* upserters + run at the SERIALIZABLE isolation level: we cannot make the same atomicity + guarantees that a native upsert can and are very vulnerable to races and + crashes. Therefore to upsert without an appropriate unique index, we acquire a + table-level lock before the emulated upsert. [*]: Some tables have unique indices added to them in the background. Those tables `T` are keys in the dictionary UNIQUE_INDEX_BACKGROUND_UPDATES, @@ -1189,7 +1179,6 @@ async def simple_upsert( values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting desc: description of the transaction, for logging and metrics - lock: True to lock the table when doing the upsert. Returns: Returns True if a row was inserted or updated (i.e. if `values` is not empty then this always returns True) @@ -1209,7 +1198,6 @@ async def simple_upsert( keyvalues, values, insertion_values, - lock=lock, db_autocommit=autocommit, ) except self.engine.module.IntegrityError as e: @@ -1232,7 +1220,6 @@ def simple_upsert_txn( values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, where_clause: Optional[str] = None, - lock: bool = True, ) -> bool: """ Pick the UPSERT method which works best on the platform. Either the @@ -1245,8 +1232,6 @@ def simple_upsert_txn( values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting where_clause: An index predicate to apply to the upsert. - lock: True to lock the table when doing the upsert. Unused when performing - a native upsert. Returns: Returns True if a row was inserted or updated (i.e. if `values` is not empty then this always returns True) @@ -1270,7 +1255,6 @@ def simple_upsert_txn( values, insertion_values=insertion_values, where_clause=where_clause, - lock=lock, ) def simple_upsert_txn_emulated( @@ -1291,14 +1275,15 @@ def simple_upsert_txn_emulated( insertion_values: additional key/values to use only when inserting where_clause: An index predicate to apply to the upsert. lock: True to lock the table when doing the upsert. + Must not be False unless the table has already been locked. Returns: Returns True if a row was inserted or updated (i.e. if `values` is not empty then this always returns True) """ insertion_values = insertion_values or {} - # We need to lock the table :(, unless we're *really* careful if lock: + # We need to lock the table :( self.engine.lock_table(txn, table) def _getwhere(key: str) -> str: @@ -1406,7 +1391,6 @@ async def simple_upsert_many( value_names: Collection[str], value_values: Collection[Collection[Any]], desc: str, - lock: bool = True, ) -> None: """ Upsert, many times. @@ -1418,8 +1402,6 @@ async def simple_upsert_many( value_names: The value column names value_values: A list of each row's value column values. Ignored if value_names is empty. - lock: True to lock the table when doing the upsert. Unused when performing - a native upsert. """ # We can autocommit if it safe to upsert @@ -1433,7 +1415,6 @@ async def simple_upsert_many( key_values, value_names, value_values, - lock=lock, db_autocommit=autocommit, ) @@ -1445,7 +1426,6 @@ def simple_upsert_many_txn( key_values: Collection[Iterable[Any]], value_names: Collection[str], value_values: Iterable[Iterable[Any]], - lock: bool = True, ) -> None: """ Upsert, many times. @@ -1457,8 +1437,6 @@ def simple_upsert_many_txn( value_names: The value column names value_values: A list of each row's value column values. Ignored if value_names is empty. - lock: True to lock the table when doing the upsert. Unused when performing - a native upsert. """ if table not in self._unsafe_to_upsert_tables: return self.simple_upsert_many_txn_native_upsert( @@ -1466,7 +1444,12 @@ def simple_upsert_many_txn( ) else: return self.simple_upsert_many_txn_emulated( - txn, table, key_names, key_values, value_names, value_values, lock=lock + txn, + table, + key_names, + key_values, + value_names, + value_values, ) def simple_upsert_many_txn_emulated( @@ -1477,7 +1460,6 @@ def simple_upsert_many_txn_emulated( key_values: Collection[Iterable[Any]], value_names: Collection[str], value_values: Iterable[Iterable[Any]], - lock: bool = True, ) -> None: """ Upsert, many times, but without native UPSERT support or batching. @@ -1489,18 +1471,16 @@ def simple_upsert_many_txn_emulated( value_names: The value column names value_values: A list of each row's value column values. Ignored if value_names is empty. - lock: True to lock the table when doing the upsert. """ # No value columns, therefore make a blank list so that the following # zip() works correctly. if not value_names: value_values = [() for x in range(len(key_values))] - if lock: - # Lock the table just once, to prevent it being done once per row. - # Note that, according to Postgres' documentation, once obtained, - # the lock is held for the remainder of the current transaction. - self.engine.lock_table(txn, "user_ips") + # Lock the table just once, to prevent it being done once per row. + # Note that, according to Postgres' documentation, once obtained, + # the lock is held for the remainder of the current transaction. + self.engine.lock_table(txn, "user_ips") for keyv, valv in zip(key_values, value_values): _keys = {x: y for x, y in zip(key_names, keyv)} @@ -2075,13 +2055,14 @@ def simple_select_one_txn( retcols: Collection[str], allow_none: bool = False, ) -> Optional[Dict[str, Any]]: - select_sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) + select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + + if keyvalues: + select_sql += " WHERE %s" % (" AND ".join("%s = ?" % k for k in keyvalues),) + txn.execute(select_sql, list(keyvalues.values())) + else: + txn.execute(select_sql) - txn.execute(select_sql, list(keyvalues.values())) row = txn.fetchone() if not row: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 18bb796b193c..913efddcdb4a 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -28,7 +28,6 @@ ) from synapse.api.constants import AccountDataTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -72,12 +71,11 @@ def __init__( # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. self._account_data_id_gen: AbstractStreamIdTracker + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) if isinstance(database.engine, PostgresEngine): - self._can_write_to_account_data = ( - self._instance_name in hs.config.worker.writers.account_data - ) - self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, @@ -99,21 +97,13 @@ def __init__( # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if self._instance_name in hs.config.worker.writers.account_data: - self._can_write_to_account_data = True - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - else: - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + is_writer=self._instance_name in hs.config.worker.writers.account_data, + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( @@ -463,9 +453,6 @@ async def add_account_data_to_room( content_json = json_encoder.encode(content) async with self._account_data_id_gen.get_next() as next_id: - # no need to lock here as room_account_data has a unique constraint - # on (user_id, room_id, account_data_type) so simple_upsert will - # retry if there is a conflict. await self.db_pool.simple_upsert( desc="add_room_account_data", table="room_account_data", @@ -475,7 +462,6 @@ async def add_account_data_to_room( "account_data_type": account_data_type, }, values={"stream_id": next_id, "content": content_json}, - lock=False, ) self._account_data_stream_cache.entity_has_changed(user_id, next_id) @@ -547,7 +533,6 @@ def _add_account_data_for_user( table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, values={"stream_id": next_id, "content": content_json}, - lock=False, ) # Ignored users get denormalized into a separate table as an optimisation. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 6bf0624b0ab9..ed757539840b 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,7 +20,7 @@ ApplicationService, ApplicationServiceState, AppServiceTransaction, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices @@ -261,7 +261,7 @@ async def create_appservice_txn( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, ) -> AppServiceTransaction: @@ -274,7 +274,7 @@ async def create_appservice_txn( events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. - one_time_key_counts: Counts of remaining one-time keys for relevant + one_time_keys_count: Counts of remaining one-time keys for relevant appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. @@ -300,7 +300,7 @@ def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction: events=events, ephemeral=ephemeral, to_device_messages=to_device_messages, - one_time_key_counts=one_time_key_counts, + one_time_keys_count=one_time_keys_count, unused_fallback_keys=unused_fallback_keys, device_list_summary=device_list_summary, ) @@ -382,7 +382,7 @@ def _get_oldest_unsent_txn( events=events, ephemeral=[], to_device_messages=[], - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) @@ -454,8 +454,6 @@ async def set_appservice_stream_type_pos( table="application_services_state", keyvalues={"as_id": service.id}, values={f"{stream_type}_stream_id": pos}, - # no need to lock when emulating upsert: as_id is a unique key - lock=False, desc="set_appservice_stream_type_pos", ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 14eb5a058500..e81d8c62d789 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -253,6 +253,7 @@ def _invalidate_caches_for_event( if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) + self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) self._attempt_to_invalidate_cache( "get_aggregation_groups_for_event", (relates_to,) ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 2c459709e4c3..56fbce440a4b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,6 @@ whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -86,28 +85,19 @@ def __init__( ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - else: - self._device_list_id_gen = SlavedIdTracker( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + is_writer=hs.config.worker.worker_app is None, + ) # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). @@ -535,7 +525,7 @@ def _get_device_updates_by_remote_txn( limit: Maximum number of device updates to return Returns: - List: List of device update tuples: + List of device update tuples: - user_id - device_id - stream_id @@ -1451,6 +1441,13 @@ def __init__( self._remove_duplicate_outbound_pokes, ) + self.db_pool.updates.register_background_index_update( + "device_lists_changes_in_room_by_room_index", + index_name="device_lists_changes_in_room_by_room_idx", + table="device_lists_changes_in_room", + columns=["room_id", "stream_id"], + ) + async def _drop_device_list_streams_non_unique_indexes( self, progress: JsonDict, batch_size: int ) -> int: @@ -1747,9 +1744,6 @@ def _update_remote_device_list_cache_entry_txn( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, values={"content": json_encoder.encode(content)}, - # we don't need to lock, because we assume we are the only thread - # updating this user's devices. - lock=False, ) txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id)) @@ -1763,9 +1757,6 @@ def _update_remote_device_list_cache_entry_txn( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, values={"stream_id": stream_id}, - # again, we can assume we are the only thread updating this user's - # extremity. - lock=False, ) async def update_remote_device_list_cache( @@ -1818,9 +1809,6 @@ def _update_remote_device_list_cache_txn( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, values={"stream_id": stream_id}, - # we don't need to lock, because we can assume we are the only thread - # updating this user's extremity. - lock=False, ) async def add_device_change_to_streams( @@ -2018,27 +2006,48 @@ def _add_device_outbound_room_poke_txn( ) async def get_uncoverted_outbound_room_pokes( - self, limit: int = 10 + self, start_stream_id: int, start_room_id: str, limit: int = 10 ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: """Get device list changes by room that have not yet been handled and written to `device_lists_outbound_pokes`. + Args: + start_stream_id: Together with `start_room_id`, indicates the position after + which to return device list changes. + start_room_id: Together with `start_stream_id`, indicates the position after + which to return device list changes. + limit: The maximum number of device list changes to return. + Returns: - A list of user ID, device ID, room ID, stream ID and optional opentracing context. + A list of user ID, device ID, room ID, stream ID and optional opentracing + context, in order of ascending (stream ID, room ID). """ sql = """ SELECT user_id, device_id, room_id, stream_id, opentracing_context FROM device_lists_changes_in_room - WHERE NOT converted_to_destinations - ORDER BY stream_id + WHERE + (stream_id, room_id) > (?, ?) AND + stream_id <= ? AND + NOT converted_to_destinations + ORDER BY stream_id ASC, room_id ASC LIMIT ? """ def get_uncoverted_outbound_room_pokes_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: - txn.execute(sql, (limit,)) + txn.execute( + sql, + ( + start_stream_id, + start_room_id, + # Avoid returning rows if there may be uncommitted device list + # changes with smaller stream IDs. + self._device_list_id_gen.get_current_token(), + limit, + ), + ) return [ ( @@ -2060,49 +2069,25 @@ async def add_device_list_outbound_pokes( user_id: str, device_id: str, room_id: str, - stream_id: Optional[int], hosts: Collection[str], context: Optional[Dict[str, str]], ) -> None: """Queue the device update to be sent to the given set of hosts, calculated from the room ID. - - Marks the associated row in `device_lists_changes_in_room` as handled, - if `stream_id` is provided. """ + if not hosts: + return def add_device_list_outbound_pokes_txn( txn: LoggingTransaction, stream_ids: List[int] ) -> None: - if hosts: - self._add_device_outbound_poke_to_stream_txn( - txn, - user_id=user_id, - device_id=device_id, - hosts=hosts, - stream_ids=stream_ids, - context=context, - ) - - if stream_id: - self.db_pool.simple_update_txn( - txn, - table="device_lists_changes_in_room", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "stream_id": stream_id, - "room_id": room_id, - }, - updatevalues={"converted_to_destinations": True}, - ) - - if not hosts: - # If there are no hosts then we don't try and generate stream IDs. - return await self.db_pool.runInteraction( - "add_device_list_outbound_pokes", - add_device_list_outbound_pokes_txn, - [], + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_id=device_id, + hosts=hosts, + stream_ids=stream_ids, + context=context, ) async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: @@ -2166,3 +2151,37 @@ def get_pending_remote_device_list_updates_for_room_txn( "get_pending_remote_device_list_updates_for_room", get_pending_remote_device_list_updates_for_room_txn, ) + + async def get_device_change_last_converted_pos(self) -> Tuple[int, str]: + """ + Get the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. + + Rows with a strictly greater position where `converted_to_destinations` is + `FALSE` have not been converted. + """ + + row = await self.db_pool.simple_select_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + retcols=["stream_id", "room_id"], + desc="get_device_change_last_converted_pos", + ) + return row["stream_id"], row["room_id"] + + async def set_device_change_last_converted_pos( + self, + stream_id: int, + room_id: str, + ) -> None: + """ + Set the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. + """ + + await self.db_pool.simple_update_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + updatevalues={"stream_id": stream_id, "room_id": room_id}, + desc="set_device_change_last_converted_pos", + ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index af59be6b4854..6240f9a75ed3 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -391,10 +391,10 @@ async def get_e2e_room_keys_version_info( Returns: A dict giving the info metadata for this backup version, with fields including: - version(str) - algorithm(str) - auth_data(object): opaque dict supplied by the client - etag(int): tag of the keys in the backup + version (str) + algorithm (str) + auth_data (object): opaque dict supplied by the client + etag (int): tag of the keys in the backup """ def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 2a4f58ed928b..643c47d608da 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -33,7 +33,7 @@ from synapse.api.constants import DeviceKeyAlgorithms from synapse.appservice import ( - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.logging.opentracing import log_kv, set_tag, trace @@ -412,10 +412,9 @@ async def get_e2e_one_time_keys( """Retrieve a number of one-time keys for a user Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - key_ids(list[str]): list of key ids (excluding algorithm) to - retrieve + user_id: id of user to get keys for + device_id: id of device to get keys for + key_ids: list of key ids (excluding algorithm) to retrieve Returns: A map from (algorithm, key_id) to json string for key @@ -515,7 +514,7 @@ def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]: async def count_bulk_e2e_one_time_keys_for_as( self, user_ids: Collection[str] - ) -> TransactionOneTimeKeyCounts: + ) -> TransactionOneTimeKeysCount: """ Counts, in bulk, the one-time keys for all the users specified. Intended to be used by application services for populating OTK counts in @@ -529,7 +528,7 @@ async def count_bulk_e2e_one_time_keys_for_as( def _count_bulk_e2e_one_time_keys_txn( txn: LoggingTransaction, - ) -> TransactionOneTimeKeyCounts: + ) -> TransactionOneTimeKeysCount: user_in_where_clause, user_parameters = make_in_list_sql_clause( self.database_engine, "user_id", user_ids ) @@ -542,7 +541,7 @@ def _count_bulk_e2e_one_time_keys_txn( """ txn.execute(sql, user_parameters) - result: TransactionOneTimeKeyCounts = {} + result: TransactionOneTimeKeysCount = {} for user_id, device_id, algorithm, count in txn: # We deliberately construct empty dictionaries for diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index c7962cfa410a..9c2b38df401e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1686,7 +1686,6 @@ async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None: }, insertion_values={}, desc="insert_insertion_extremity", - lock=False, ) async def insert_received_event_to_staging( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index c4acff5be629..0f097a2927c1 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1279,9 +1279,10 @@ def _filter_events_and_contexts_for_duplicates( Pick the earliest non-outlier if there is one, else the earliest one. Args: - events_and_contexts (list[(EventBase, EventContext)]): + events_and_contexts: + Returns: - list[(EventBase, EventContext)]: filtered list + filtered list """ new_events_and_contexts: OrderedDict[ str, Tuple[EventBase, EventContext] @@ -1307,9 +1308,8 @@ def _update_room_depths_txn( """Update min_depth for each room Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting + txn: db connection + events_and_contexts: events we are persisting """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: @@ -1580,13 +1580,11 @@ def _update_metadata_tables_txn( """Update all the miscellaneous tables for new events Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. + txn: db connection + events_and_contexts: events we are persisting + all_events_and_contexts: all events that we were going to persist. + This includes events we've already persisted, etc, that wouldn't + appear in events_and_context. inhibit_local_membership_updates: Stop the local_current_membership from being updated by these events. This should be set to True for backfilled events because backfilled events in the past do @@ -2051,6 +2049,10 @@ def _handle_redact_relations( self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REFERENCE: + self.store._invalidate_cache_and_stream( + txn, self.store.get_references_for_event, (redacted_relates_to,) + ) if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4c921d04c7db..f25407c9c1be 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -60,7 +60,6 @@ run_as_background_process, wrap_as_background_process, ) -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import ( EventsStream, @@ -219,26 +218,20 @@ def __init__( # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering" - ) - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( @@ -1638,7 +1631,7 @@ async def get_room_complexity(self, room_id: str) -> Dict[str, float]: room_id: The room ID to query. Returns: - dict[str:float] of complexity version to complexity. + Map of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index efd136a86474..db9a24db5ec6 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -217,7 +217,7 @@ async def reap_monthly_active_users(self) -> None: def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None: """ Args: - reserved_users (tuple): reserved users to preserve + reserved_users: reserved users to preserve """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) @@ -370,8 +370,8 @@ def upsert_monthly_active_user_txn( should not appear in the MAU stats). Args: - txn (cursor): - user_id (str): user to add/update + txn: + user_id: user to add/update """ assert ( self._update_on_this_worker @@ -401,7 +401,7 @@ async def populate_monthly_active_users(self, user_id: str) -> None: add the user to the monthly active tables Args: - user_id(str): the user_id to query + user_id: the user_id to query """ assert ( self._update_on_this_worker diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8ae10f61270f..d4c64c46ad44 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,7 +30,6 @@ from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -85,7 +84,10 @@ def _load_rules( push_rules = PushRules(ruleslist) filtered_rules = FilteredPushRules( - push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled + push_rules, + enabled_map, + msc3664_enabled=experimental_config.msc3664_enabled, + msc1767_enabled=experimental_config.msc1767_enabled, ) return filtered_rules @@ -111,14 +113,14 @@ def __init__( ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) - else: - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "push_rules_stream", + "stream_id", + is_writer=hs.config.worker.worker_app is None, + ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 4a01562d4552..40fd781a6ab1 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,7 +27,6 @@ ) from synapse.push import PusherConfig, ThrottleParams -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -59,20 +58,15 @@ def __init__( ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) - else: - self._pushers_id_gen = SlavedIdTracker( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + is_writer=hs.config.worker.worker_app is None, + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", @@ -331,14 +325,11 @@ async def get_throttle_params_by_room( async def set_throttle_params( self, pusher_id: str, room_id: str, params: ThrottleParams ) -> None: - # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so simple_upsert will retry await self.db_pool.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms}, desc="set_throttle_params", - lock=False, ) async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int: @@ -595,8 +586,6 @@ async def add_pusher( device_id: Optional[str] = None, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: - # no need to lock because `pushers` has a unique key on - # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -615,7 +604,6 @@ async def add_pusher( "device_id": device_id, }, desc="add_pusher", - lock=False, ) user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 2ec54f87110c..1616ed45ced4 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -28,7 +28,6 @@ from synapse.api.constants import EduTypes from synapse.api.errors import StoreError -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -63,6 +62,9 @@ def __init__( hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): @@ -89,14 +91,12 @@ def __init__( # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.receipts: - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - else: - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) + self._receipts_id_gen = StreamIdGenerator( + db_conn, + "receipts_linearized", + "stream_id", + is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, + ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 5167089e03b4..31f0f2bd3d84 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -953,7 +953,7 @@ def get_user_id_by_threepid_txn( """Returns user id from threepid Args: - txn (cursor): + txn: medium: threepid medium e.g. email address: threepid address e.g. me@example.com @@ -1283,8 +1283,8 @@ def set_expiration_date_for_user_txn( """Sets an expiration date to the account with the given user ID. Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be + user_id: User ID to set an expiration date for. + use_delta: If set to False, the expiration date for the user will be now + validity period. If set to True, this expiration date will be a random value in the [now + period - d ; now + period] range, d being a delta equal to 10% of the validity period. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index ca431002c8d7..aea96e9d2478 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -20,6 +20,7 @@ FrozenSet, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -81,8 +82,6 @@ class _RelatedEvent: event_id: str # The sender of the related event. sender: str - topological_ordering: Optional[int] - stream_ordering: int class RelationsWorkerStore(SQLBaseStore): @@ -245,13 +244,17 @@ def _get_recent_references_for_event_txn( txn.execute(sql, where_args + [limit + 1]) events = [] - for event_id, relation_type, sender, topo_ordering, stream_ordering in txn: + topo_orderings: List[int] = [] + stream_orderings: List[int] = [] + for event_id, relation_type, sender, topo_ordering, stream_ordering in cast( + List[Tuple[str, str, str, int, int]], txn + ): # Do not include edits for redacted events as they leak event # content. if not is_redacted or relation_type != RelationTypes.REPLACE: - events.append( - _RelatedEvent(event_id, sender, topo_ordering, stream_ordering) - ) + events.append(_RelatedEvent(event_id, sender)) + topo_orderings.append(topo_ordering) + stream_orderings.append(stream_ordering) # If there are more events, generate the next pagination key from the # last event returned. @@ -260,9 +263,11 @@ def _get_recent_references_for_event_txn( # Instead of using the last row (which tells us there is more # data), use the last row to be returned. events = events[:limit] + topo_orderings = topo_orderings[:limit] + stream_orderings = stream_orderings[:limit] - topo = events[-1].topological_ordering - token = events[-1].stream_ordering + topo = topo_orderings[-1] + token = stream_orderings[-1] if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk. @@ -394,111 +399,195 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool: ) return result is not None - @cached(tree=True) - async def get_aggregation_groups_for_event( - self, event_id: str, room_id: str, limit: int = 5 - ) -> List[JsonDict]: - """Get a list of annotations on the event, grouped by event type and + @cached() + async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_aggregation_groups_for_event", list_name="event_ids" + ) + async def get_aggregation_groups_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[JsonDict]]]: + """Get a list of annotations on the given events, grouped by event type and aggregation key, sorted by count. This is used e.g. to get the what and how many reactions have happend on an event. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. Returns: - List of groups of annotations that match. Each row is a dict with - `type`, `key` and `count` fields. + A map of event IDs to a list of groups of annotations that match. + Each entry is a dict with `type`, `key` and `count` fields. """ + # The number of entries to return per event ID. + limit = 5 - args = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - limit, - ] + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.ANNOTATION) - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? + sql = f""" + SELECT + relates_to_id, + annotation.type, + aggregation_key, + COUNT(DISTINCT annotation.sender) + FROM events AS annotation + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = annotation.room_id + WHERE + {clause} + AND relation_type = ? + GROUP BY relates_to_id, annotation.type, aggregation_key + ORDER BY relates_to_id, COUNT(*) DESC """ - def _get_aggregation_groups_for_event_txn( + def _get_aggregation_groups_for_events_txn( txn: LoggingTransaction, - ) -> List[JsonDict]: + ) -> Mapping[str, List[JsonDict]]: txn.execute(sql, args) - return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] + result: Dict[str, List[JsonDict]] = {} + for event_id, type, key, count in cast( + List[Tuple[str, str, str, int]], txn + ): + event_results = result.setdefault(event_id, []) + + # Limit the number of results per event ID. + if len(event_results) == limit: + continue + + event_results.append({"type": type, "key": key, "count": count}) + + return result return await self.db_pool.runInteraction( - "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn ) async def get_aggregation_groups_for_users( - self, - event_id: str, - room_id: str, - limit: int, - users: FrozenSet[str] = frozenset(), - ) -> Dict[Tuple[str, str], int]: + self, event_ids: Collection[str], users: FrozenSet[str] + ) -> Dict[str, Dict[Tuple[str, str], int]]: """Fetch the partial aggregations for an event for specific users. This is used, in conjunction with get_aggregation_groups_for_event, to remove information from the results for ignored users. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. users: The users to fetch information for. Returns: - A map of (event type, aggregation key) to a count of users. + A map of event ID to a map of (event type, aggregation key) to a + count of users. """ if not users: return {} - args: List[Union[str, int]] = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - ] + events_sql, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) users_sql, users_args = make_in_list_sql_clause( - self.database_engine, "sender", users + self.database_engine, "annotation.sender", users ) args.extend(users_args) + args.append(RelationTypes.ANNOTATION) sql = f""" - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql} - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? + SELECT + relates_to_id, + annotation.type, + aggregation_key, + COUNT(DISTINCT annotation.sender) + FROM events AS annotation + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = annotation.room_id + WHERE {events_sql} AND {users_sql} AND relation_type = ? + GROUP BY relates_to_id, annotation.type, aggregation_key + ORDER BY relates_to_id, COUNT(*) DESC """ def _get_aggregation_groups_for_users_txn( txn: LoggingTransaction, - ) -> Dict[Tuple[str, str], int]: - txn.execute(sql, args + [limit]) + ) -> Dict[str, Dict[Tuple[str, str], int]]: + txn.execute(sql, args) - return {(row[0], row[1]): row[2] for row in txn} + result: Dict[str, Dict[Tuple[str, str], int]] = {} + for event_id, type, key, count in cast( + List[Tuple[str, str, str, int]], txn + ): + result.setdefault(event_id, {})[(type, key)] = count + + return result return await self.db_pool.runInteraction( "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn ) + @cached() + async def get_references_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_references_for_event", list_name="event_ids") + async def get_references_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[_RelatedEvent]]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to these event IDs. + + Returns: + A map of event IDs to a list of related event IDs (and their senders). + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.REFERENCE) + + sql = f""" + SELECT relates_to_id, ref.event_id, ref.sender + FROM events AS ref + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = ref.room_id + WHERE + {clause} + AND relation_type = ? + ORDER BY ref.topological_ordering, ref.stream_ordering + """ + + def _get_references_for_events_txn( + txn: LoggingTransaction, + ) -> Mapping[str, List[_RelatedEvent]]: + txn.execute(sql, args) + + result: Dict[str, List[_RelatedEvent]] = {} + for relates_to_id, event_id, sender in cast( + List[Tuple[str, str, str]], txn + ): + result.setdefault(relates_to_id, []).append( + _RelatedEvent(event_id, sender) + ) + + return result + + return await self.db_pool.runInteraction( + "_get_references_for_events_txn", _get_references_for_events_txn + ) + @cached() def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7d97f8f60e7c..1309bfd374c3 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -912,7 +912,11 @@ def _get_media_mxcs_in_room_txn( event_json = db_to_json(content_json) content = event_json["content"] content_url = content.get("url") - thumbnail_url = content.get("info", {}).get("thumbnail_url") + info = content.get("info") + if isinstance(info, dict): + thumbnail_url = info.get("thumbnail_url") + else: + thumbnail_url = None for url in (content_url, thumbnail_url): if not url: @@ -1843,9 +1847,6 @@ async def upsert_room_on_join( "creator": room_creator, "has_auth_chain_index": has_auth_chain_index, }, - # rooms has a unique constraint on room_id, so no need to lock when doing an - # emulated upsert. - lock=False, ) async def store_partial_state_room( @@ -1966,9 +1967,6 @@ async def maybe_store_room_on_outlier_membership( "creator": "", "has_auth_chain_index": has_auth_chain_index, }, - # rooms has a unique constraint on room_id, so no need to lock when doing an - # emulated upsert. - lock=False, ) async def set_room_is_public(self, room_id: str, is_public: bool) -> None: @@ -2057,7 +2055,8 @@ async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: Args: report_id: ID of reported event in database Returns: - event_report: json list of information from event report + JSON dict of information from an event report or None if the + report does not exist. """ def _get_event_report_txn( @@ -2130,8 +2129,9 @@ async def get_event_reports_paginate( user_id: search for user_id. Ignored if user_id is None room_id: search for room_id. Ignored if room_id is None Returns: - event_reports: json list of event reports - count: total number of event reports matching the filter criteria + Tuple of: + json list of event reports + total number of event reports matching the filter criteria """ def _get_event_reports_paginate_txn( diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py index 39e80f6f5b11..131f357d04d5 100644 --- a/synapse/storage/databases/main/room_batch.py +++ b/synapse/storage/databases/main/room_batch.py @@ -44,6 +44,4 @@ async def store_state_group_id_for_event_id( table="event_to_state_groups", keyvalues={"event_id": event_id}, values={"state_group": state_group_id, "event_id": event_id}, - # Unique constraint on event_id so we don't have to lock - lock=False, ) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index ddb25b5cea7f..044435deab4e 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -185,9 +185,8 @@ async def _populate_user_directory_process_rooms( - who should be in the user_directory. Args: - progress (dict) - batch_size (int): Maximum number of state events to process - per cycle. + progress + batch_size: Maximum number of state events to process per cycle. Returns: number of events processed. @@ -482,7 +481,6 @@ def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None: table="user_directory", keyvalues={"user_id": user_id}, values={"display_name": display_name, "avatar_url": avatar_url}, - lock=False, # We're only inserter ) if isinstance(self.database_engine, PostgresEngine): @@ -512,7 +510,6 @@ def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None: table="user_directory_search", keyvalues={"user_id": user_id}, values={"value": value}, - lock=False, # We're only inserter ) else: # This should be unreachable. @@ -708,10 +705,10 @@ async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]: Returns the rooms that a user is in. Args: - user_id(str): Must be a local user + user_id: Must be a local user Returns: - list: user_id + List of room IDs """ rows = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index a7fcc564a992..4a4ad0f49288 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -93,13 +93,6 @@ def _get_state_groups_from_groups_txn( results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is # a temporary hack until we can add the right indices in @@ -110,31 +103,91 @@ def _get_state_groups_from_groups_txn( # against `state_groups_state` to fetch the latest state. # It assumes that previous state groups are always numerically # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. # This may return multiple rows per (type, state_key), but last_value # should be the same. sql = """ - WITH RECURSIVE state(state_group) AS ( + WITH RECURSIVE sgs(state_group) AS ( VALUES(?::bigint) UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s + SELECT prev_state_group FROM state_group_edges e, sgs s WHERE s.state_group = e.state_group ) - SELECT DISTINCT ON (type, state_key) - type, state_key, event_id - FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) %s - ORDER BY type, state_key, state_group DESC + %s """ + overall_select_query_args: List[Union[int, str]] = [] + + # This is an optimization to create a select clause per-condition. This + # makes the query planner a lot smarter on what rows should pull out in the + # first place and we end up with something that takes 10x less time to get a + # result. + use_condition_optimization = ( + not state_filter.include_others and not state_filter.is_full() + ) + state_filter_condition_combos: List[Tuple[str, Optional[str]]] = [] + # We don't need to caclculate this list if we're not using the condition + # optimization + if use_condition_optimization: + for etype, state_keys in state_filter.types.items(): + if state_keys is None: + state_filter_condition_combos.append((etype, None)) + else: + for state_key in state_keys: + state_filter_condition_combos.append((etype, state_key)) + # And here is the optimization itself. We don't want to do the optimization + # if there are too many individual conditions. 10 is an arbitrary number + # with no testing behind it but we do know that we specifically made this + # optimization for when we grab the necessary state out for + # `filter_events_for_client` which just uses 2 conditions + # (`EventTypes.RoomHistoryVisibility` and `EventTypes.Member`). + if use_condition_optimization and len(state_filter_condition_combos) < 10: + select_clause_list: List[str] = [] + for etype, skey in state_filter_condition_combos: + if skey is None: + where_clause = "(type = ?)" + overall_select_query_args.extend([etype]) + else: + where_clause = "(type = ? AND state_key = ?)" + overall_select_query_args.extend([etype, skey]) + + select_clause_list.append( + f""" + ( + SELECT DISTINCT ON (type, state_key) + type, state_key, event_id + FROM state_groups_state + INNER JOIN sgs USING (state_group) + WHERE {where_clause} + ORDER BY type, state_key, state_group DESC + ) + """ + ) + + overall_select_clause = " UNION ".join(select_clause_list) + else: + where_clause, where_args = state_filter.make_sql_filter_clause() + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + overall_select_query_args.extend(where_args) + + overall_select_clause = f""" + SELECT DISTINCT ON (type, state_key) + type, state_key, event_id + FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM sgs + ) {where_clause} + ORDER BY type, state_key, state_group DESC + """ + for group in groups: args: List[Union[int, str]] = [group] - args.extend(where_args) + args.extend(overall_select_query_args) - txn.execute(sql % (where_clause,), args) + txn.execute(sql % (overall_select_clause,), args) for row in txn: typ, state_key, event_id = row key = (intern_string(typ), intern_string(state_key)) @@ -142,6 +195,12 @@ def _get_state_groups_from_groups_txn( else: max_entries_returned = state_filter.max_entries_returned() + where_clause, where_args = state_filter.make_sql_filter_clause() + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) for group in groups: diff --git a/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql new file mode 100644 index 000000000000..93d7fcb79b8c --- /dev/null +++ b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql @@ -0,0 +1,53 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Prior to this schema delta, we tracked the set of unconverted rows in +-- `device_lists_changes_in_room` using the `converted_to_destinations` flag. When rows +-- were converted to `device_lists_outbound_pokes`, the `converted_to_destinations` flag +-- would be set. +-- +-- After this schema delta, the `converted_to_destinations` is still populated like +-- before, but the set of unconverted rows is determined by the `stream_id` in the new +-- `device_lists_changes_converted_stream_position` table. +-- +-- If rolled back, Synapse will re-send all device list changes that happened since the +-- schema delta. + +CREATE TABLE IF NOT EXISTS device_lists_changes_converted_stream_position( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + -- The (stream id, room id) of the last row in `device_lists_changes_in_room` that + -- has been converted to `device_lists_outbound_pokes`. Rows with a strictly larger + -- (stream id, room id) where `converted_to_destinations` is `FALSE` have not been + -- converted. + stream_id BIGINT NOT NULL, + -- `room_id` may be an empty string, which compares less than all valid room IDs. + room_id TEXT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO device_lists_changes_converted_stream_position (stream_id, room_id) VALUES ( + ( + SELECT COALESCE( + -- The last converted stream id is the smallest unconverted stream id minus + -- one. + MIN(stream_id) - 1, + -- If there is no unconverted stream id, the last converted stream id is the + -- largest stream id. + -- Otherwise, pick 1, since stream ids start at 2. + (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room) + ) FROM device_lists_changes_in_room WHERE NOT converted_to_destinations + ), + '' +); diff --git a/synapse/storage/schema/main/delta/73/13add_device_lists_index.sql b/synapse/storage/schema/main/delta/73/13add_device_lists_index.sql new file mode 100644 index 000000000000..3725022a1336 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/13add_device_lists_index.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Adds an index on `device_lists_changes_in_room (room_id, stream_id)`, which +-- speeds up `/sync` queries. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7313, 'device_lists_changes_in_room_by_room_index', '{}'); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 2dfe4c0b6615..0d7108f01b41 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -186,11 +186,13 @@ def __init__( column: str, extra_tables: Iterable[Tuple[str, str]] = (), step: int = 1, + is_writer: bool = True, ) -> None: assert step != 0 self._lock = threading.Lock() self._step: int = step self._current: int = _load_current_id(db_conn, table, column, step) + self._is_writer = is_writer for table, column in extra_tables: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) @@ -204,9 +206,11 @@ def __init__( self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def advance(self, instance_name: str, new_id: int) -> None: - # `StreamIdGenerator` should only be used when there is a single writer, - # so replication should never happen. - raise Exception("Replication is not supported by StreamIdGenerator") + # Advance should never be called on a writer instance, only over replication + if self._is_writer: + raise Exception("Replication is not supported by writer StreamIdGenerator") + + self._current = (max if self._step > 0 else min)(self._current, new_id) def get_next(self) -> AsyncContextManager[int]: with self._lock: @@ -249,6 +253,9 @@ def manager() -> Generator[Sequence[int], None, None]: return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: + if not self._is_writer: + return self._current + with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step diff --git a/synapse/streams/events.py b/synapse/streams/events.py index f331e1af16e0..619eb7f601de 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -73,6 +73,19 @@ def get_current_token(self) -> StreamToken: ) return token + @trace + async def get_start_token_for_pagination(self, room_id: str) -> StreamToken: + """Get the start token for a given room to be used to paginate + events. + + The returned token does not have the current values for fields other + than `room`, since they are not used during pagination. + + Returns: + The start token for pagination. + """ + return StreamToken.START + @trace async def get_current_token_for_pagination(self, room_id: str) -> StreamToken: """Get the current token for a given room to be used to paginate diff --git a/synapse/types.py b/synapse/types.py index 773f0438d5bd..f2d436ddc38c 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -143,8 +143,8 @@ def deserialize( Requester. Args: - store (DataStore): Used to convert AS ID to AS object - input (dict): A dict produced by `serialize` + store: Used to convert AS ID to AS object + input: A dict produced by `serialize` Returns: Requester diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 7f1d41eb3c7a..d24c4f68c4da 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -217,7 +217,8 @@ async def concurrently_execute( limit: Maximum number of conccurent executions. Returns: - Deferred: Resolved when all function invocations have finished. + None, when all function invocations have finished. The return values + from those functions are discarded. """ it = iter(args) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index f7c3a6794ed0..9387632d0d74 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -197,7 +197,7 @@ def register_cache( resize_callback: A function which can be called to resize the cache. Returns: - CacheMetric: an object which provides inc_{hits,misses,evictions} methods + an object which provides inc_{hits,misses,evictions} methods """ if resizable: if not resize_callback: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index bcb1cba3620a..bf7bd351e0cc 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -153,7 +153,7 @@ def get( Args: key: callback: Gets called when the entry in the cache is invalidated - update_metrics (bool): whether to update the cache hit rate metrics + update_metrics: whether to update the cache hit rate metrics Returns: A Deferred which completes with the result. Note that this may later fail diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index ba69cfc005d8..bbe5ddeb6452 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -565,7 +565,7 @@ def cachedList( is specified as a list that is iterated through to lookup keys in the original cache. A new tuple consisting of the (deduplicated) keys that weren't in the cache gets passed to the original function, which is expected to results - in a map of key to value for each passed value. THe new results are stored in the + in a map of key to value for each passed value. The new results are stored in the original cache. Note that any missing values are cached as None. Args: diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index fa91479c97f6..5eaf70c7abba 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -169,10 +169,11 @@ def get( if it is in the cache. Returns: - DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry` - will contain include the keys that are in the cache. If None then - will either return the full dict if in the cache, or the empty - dict (with `full` set to False) if it isn't. + If `dict_keys` is not None then `DictionaryEntry` will contain include + the keys that are in the cache. + + If None then will either return the full dict if in the cache, or the + empty dict (with `full` set to False) if it isn't. """ if dict_keys is None: # The caller wants the full set of dictionary keys for this cache key diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index c6a5d0dfc0a9..01ad02af6703 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -207,7 +207,7 @@ def set_cache_factor(self, factor: float) -> bool: items from the cache. Returns: - bool: Whether the cache changed size or not. + Whether the cache changed size or not. """ new_size = int(self._original_max_size * factor) if new_size != self._max_size: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index aa93109d1380..dcf0eac3bf08 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -389,11 +389,11 @@ def __init__( cache_name: The name of this cache, for the prometheus metrics. If unset, no metrics will be reported on this cache. - cache_type (type): + cache_type: type of underlying cache to be used. Typically one of dict or TreeCache. - size_callback (func(V) -> int | None): + size_callback: metrics_collection_callback: metrics collection callback. This is called early in the metrics @@ -403,7 +403,7 @@ def __init__( Ignored if cache_name is None. - apply_cache_factor_from_config (bool): If true, `max_size` will be + apply_cache_factor_from_config: If true, `max_size` will be multiplied by a cache factor derived from the homeserver config clock: @@ -796,7 +796,7 @@ def set_cache_factor(self, factor: float) -> bool: items from the cache. Returns: - bool: Whether the cache changed size or not. + Whether the cache changed size or not. """ if not self.apply_cache_factor_from_config: return False diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 9f64fed0d764..2aceb1a47fb6 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -183,7 +183,7 @@ def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None] # Handle request ... Args: - host (str): Origin of incoming request. + host: Origin of incoming request. Returns: context manager which returns a deferred. diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 1e9c2faa644a..54bc7589fd51 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -48,7 +48,7 @@ async def check_3pid_allowed( registration: whether we want to bind the 3PID as part of registering a new user. Returns: - bool: whether the 3PID medium/address is allowed to be added to this HS + whether the 3PID medium/address is allowed to be added to this HS """ if not await hs.get_password_auth_provider().is_3pid_allowed( medium, address, registration diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 177e198e7e75..b1ec7f4bd8bb 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -90,10 +90,10 @@ def fetch(self, now: int) -> List[T]: """Fetch any objects that have timed out Args: - now (ms): Current time in msec + now: Current time in msec Returns: - list: List of objects that have timed out + List of objects that have timed out """ now_key = int(now / self.bucket_size) diff --git a/synapse/visibility.py b/synapse/visibility.py index 40a9c5b53f83..b443857571b7 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -563,7 +563,8 @@ def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str: async def filter_events_for_server( storage: StorageControllers, - server_name: str, + target_server_name: str, + local_server_name: str, events: List[EventBase], redact: bool = True, check_history_visibility_only: bool = False, @@ -603,7 +604,7 @@ def check_event_is_visible( # if the server is either in the room or has been invited # into the room. for ev in memberships.values(): - assert get_domain_from_id(ev.state_key) == server_name + assert get_domain_from_id(ev.state_key) == target_server_name memtype = ev.membership if memtype == Membership.JOIN: @@ -622,6 +623,24 @@ def check_event_is_visible( # to no users having been erased. erased_senders = {} + # Filter out non-local events when we are in the middle of a partial join, since our servers + # list can be out of date and we could leak events to servers not in the room anymore. + # This can also be true for local events but we consider it to be an acceptable risk. + + # We do this check as a first step and before retrieving membership events because + # otherwise a room could be fully joined after we retrieve those, which would then bypass + # this check but would base the filtering on an outdated view of the membership events. + + partial_state_invisible_events = set() + if not check_history_visibility_only: + for e in events: + sender_domain = get_domain_from_id(e.sender) + if ( + sender_domain != local_server_name + and await storage.main.is_partial_state_room(e.room_id) + ): + partial_state_invisible_events.add(e) + # Let's check to see if all the events have a history visibility # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). @@ -636,7 +655,7 @@ def check_event_is_visible( if event_to_history_vis[e.event_id] not in (HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE) ], - server_name, + target_server_name, ) to_return = [] @@ -645,6 +664,10 @@ def check_event_is_visible( visible = check_event_is_visible( event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {}) ) + + if e in partial_state_invisible_events: + visible = False + if visible and not erased: to_return.append(e) elif redact: diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 0b22afdc7598..0a1ae83a2bbf 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -69,7 +69,7 @@ def test_single_service_up_txn_sent(self): events=events, ephemeral=[], to_device_messages=[], # txn made and saved - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) @@ -96,7 +96,7 @@ def test_single_service_down(self): events=events, ephemeral=[], to_device_messages=[], # txn made and saved - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) @@ -125,7 +125,7 @@ def test_single_service_up_txn_not_sent(self): events=events, ephemeral=[], to_device_messages=[], - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 820a1a54e2e0..63628aa6b066 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -469,6 +469,18 @@ async def get_json(destination, path, **kwargs): keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) self.assertEqual(keys, {}) + def test_keyid_containing_forward_slash(self) -> None: + """We should url-encode any url unsafe chars in key ids. + + Detects https://github.com/matrix-org/synapse/issues/14488. + """ + fetcher = ServerKeyFetcher(self.hs) + self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0)) + + self.http_client.get_json.assert_called_once() + args, kwargs = self.http_client.get_json.call_args + self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato") + class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index f1e357764ff4..01f147418b9b 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -83,6 +83,83 @@ def test_send_receipts(self): ], ) + @override_config({"send_federation": True}) + def test_send_receipts_thread(self): + mock_send_transaction = ( + self.hs.get_federation_transport_client().send_transaction + ) + mock_send_transaction.return_value = make_awaitable({}) + + # Create receipts for: + # + # * The same room / user on multiple threads. + # * A different user in the same room. + sender = self.hs.get_federation_sender() + for user, thread in ( + ("alice", None), + ("alice", "thread"), + ("bob", None), + ("bob", "diff-thread"), + ): + receipt = ReadReceipt( + "room_id", + "m.read", + user, + ["event_id"], + thread_id=thread, + data={"ts": 1234}, + ) + self.successResultOf( + defer.ensureDeferred(sender.send_read_receipt(receipt)) + ) + + self.pump() + + # expect a call to send_transaction with two EDUs to separate threads. + mock_send_transaction.assert_called_once() + json_cb = mock_send_transaction.call_args[0][1] + data = json_cb() + # Note that the ordering of the EDUs doesn't matter. + self.assertCountEqual( + data["edus"], + [ + { + "edu_type": EduTypes.RECEIPT, + "content": { + "room_id": { + "m.read": { + "alice": { + "event_ids": ["event_id"], + "data": {"ts": 1234, "thread_id": "thread"}, + }, + "bob": { + "event_ids": ["event_id"], + "data": {"ts": 1234, "thread_id": "diff-thread"}, + }, + } + } + }, + }, + { + "edu_type": EduTypes.RECEIPT, + "content": { + "room_id": { + "m.read": { + "alice": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, + }, + "bob": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, + }, + } + } + }, + }, + ], + ) + @override_config({"send_federation": True}) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 144e49d0fd9c..9ed26d87a753 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -25,7 +25,7 @@ from synapse.api.constants import EduTypes, EventTypes from synapse.appservice import ( ApplicationService, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.handlers.appservice import ApplicationServicesHandler @@ -1123,7 +1123,7 @@ def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pd # Capture what was sent as an AS transaction. self.send_mock.assert_called() last_args, _last_kwargs = self.send_mock.call_args - otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS] + otks: Optional[TransactionOneTimeKeysCount] = last_args[self.ARG_OTK_COUNTS] unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[ self.ARG_FALLBACK_KEYS ] diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b8b465d35b8f..ce7525e29c0a 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,7 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError -from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN +from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler from synapse.server import HomeServer from synapse.util import Clock @@ -32,7 +32,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.store = hs.get_datastores().main return hs @@ -61,6 +63,7 @@ def test_device_is_created_if_doesnt_exist(self) -> None: self.assertEqual(res, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_is_preserved_if_exists(self) -> None: @@ -83,6 +86,7 @@ def test_device_is_preserved_if_exists(self) -> None: self.assertEqual(res2, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_id_is_made_up_if_unspecified(self) -> None: @@ -95,6 +99,7 @@ def test_device_id_is_made_up_if_unspecified(self) -> None: ) dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) + assert dev is not None self.assertEqual(dev["display_name"], "display") def test_get_devices_by_user(self) -> None: @@ -264,7 +269,9 @@ def _record_user( class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.registration = hs.get_registration_handler() self.auth = hs.get_auth() self.store = hs.get_datastores().main @@ -284,9 +291,9 @@ def test_dehydrate_and_rehydrate_device(self) -> None: ) ) - retrieved_device_id, device_data = self.get_success( - self.handler.get_dehydrated_device(user_id=user_id) - ) + result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) + assert result is not None + retrieved_device_id, device_data = result self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index c96dc6caf2d5..c5981ff9657f 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -15,6 +15,7 @@ from typing import Optional from unittest.mock import Mock, call +from parameterized import parameterized from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState @@ -37,6 +38,7 @@ from synapse.types import UserID, get_domain_from_id from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase class PresenceUpdateTestCase(unittest.HomeserverTestCase): @@ -505,7 +507,7 @@ def test_last_active(self): self.assertEqual(state, new_state) -class PresenceHandlerTestCase(unittest.HomeserverTestCase): +class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): def prepare(self, reactor, clock, hs): self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() @@ -716,20 +718,47 @@ def test_set_presence_from_syncing_keeps_status(self): # our status message should be the same as it was before self.assertEqual(state.status_msg, status_msg) - def test_set_presence_from_syncing_keeps_busy(self): - """Test that presence set by syncing doesn't affect busy status""" - # while this isn't the default - self.presence_handler._busy_presence_enabled = True + @parameterized.expand([(False,), (True,)]) + @unittest.override_config( + { + "experimental_features": { + "msc3026_enabled": True, + }, + } + ) + def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool): + """Test that presence set by syncing doesn't affect busy status + Args: + test_with_workers: If True, check the presence state of the user by calling + /sync against a worker, rather than the main process. + """ user_id = "@test:server" status_msg = "I'm busy!" + # By default, we call /sync against the main process. + worker_to_sync_against = self.hs + if test_with_workers: + # Create a worker and use it to handle /sync traffic instead. + # This is used to test that presence changes get replicated from workers + # to the main process correctly. + worker_to_sync_against = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "presence_writer"} + ) + + # Set presence to BUSY self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg) + # Perform a sync with a presence state other than busy. This should NOT change + # our presence status; we only change from busy if we explicitly set it via + # /presence/*. self.get_success( - self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE) + worker_to_sync_against.get_presence_handler().user_syncing( + user_id, True, PresenceState.ONLINE + ) ) + # Check against the main process that the user's presence did not change. state = self.get_success( self.presence_handler.get_state(UserID.from_string(user_id)) ) diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py new file mode 100644 index 000000000000..137deab138b5 --- /dev/null +++ b/tests/handlers/test_sso.py @@ -0,0 +1,145 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from http import HTTPStatus +from typing import BinaryIO, Callable, Dict, List, Optional, Tuple +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.http_headers import Headers + +from synapse.api.errors import Codes, SynapseError +from synapse.http.client import RawHeaders +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.test_utils import SMALL_PNG, FakeResponse + + +class TestSSOHandler(unittest.HomeserverTestCase): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.http_client = Mock(spec=["get_file"]) + self.http_client.get_file.side_effect = mock_get_file + self.http_client.user_agent = b"Synapse Test" + hs = self.setup_test_homeserver( + proxied_blacklisted_http_client=self.http_client + ) + return hs + + async def test_set_avatar(self) -> None: + """Tests successfully setting the avatar of a newly created user""" + handler = self.hs.get_sso_handler() + + # Create a new user to set avatar for + reg_handler = self.hs.get_registration_handler() + user_id = self.get_success(reg_handler.register_user(approved=True)) + + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # Ensure avatar is set on this newly created user, + # so no need to compare for the exact image + profile_handler = self.hs.get_profile_handler() + profile = self.get_success(profile_handler.get_profile(user_id)) + self.assertIsNot(profile["avatar_url"], None) + + @unittest.override_config({"max_avatar_size": 1}) + async def test_set_avatar_too_big_image(self) -> None: + """Tests that saving an avatar fails when it is too big""" + handler = self.hs.get_sso_handler() + + # any random user works since image check is supposed to fail + user_id = "@sso-user:test" + + self.assertFalse( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + @unittest.override_config({"allowed_avatar_mimetypes": ["image/jpeg"]}) + async def test_set_avatar_incorrect_mime_type(self) -> None: + """Tests that saving an avatar fails when its mime type is not allowed""" + handler = self.hs.get_sso_handler() + + # any random user works since image check is supposed to fail + user_id = "@sso-user:test" + + self.assertFalse( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + async def test_skip_saving_avatar_when_not_changed(self) -> None: + """Tests whether saving of avatar correctly skips if the avatar hasn't + changed""" + handler = self.hs.get_sso_handler() + + # Create a new user to set avatar for + reg_handler = self.hs.get_registration_handler() + user_id = self.get_success(reg_handler.register_user(approved=True)) + + # set avatar for the first time, should be a success + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # get avatar picture for comparison after another attempt + profile_handler = self.hs.get_profile_handler() + profile = self.get_success(profile_handler.get_profile(user_id)) + url_to_match = profile["avatar_url"] + + # set same avatar for the second time, should be a success + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # compare avatar picture's url from previous step + profile = self.get_success(profile_handler.get_profile(user_id)) + self.assertEqual(profile["avatar_url"], url_to_match) + + +async def mock_get_file( + url: str, + output_stream: BinaryIO, + max_size: Optional[int] = None, + headers: Optional[RawHeaders] = None, + is_allowed_content_type: Optional[Callable[[str], bool]] = None, +) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: + + fake_response = FakeResponse(code=404) + if url == "http://my.server/me.png": + fake_response = FakeResponse( + code=200, + headers=Headers( + {"Content-Type": ["image/png"], "Content-Length": [str(len(SMALL_PNG))]} + ), + body=SMALL_PNG, + ) + + if max_size is not None and max_size < len(SMALL_PNG): + raise SynapseError( + HTTPStatus.BAD_GATEWAY, + "Requested file is too large > %r bytes" % (max_size,), + Codes.TOO_LARGE, + ) + + if is_allowed_content_type and not is_allowed_content_type("image/png"): + raise SynapseError( + HTTPStatus.BAD_GATEWAY, + ( + "Requested file's content type not allowed for this operation: %s" + % "image/png" + ), + ) + + output_stream.write(fake_response.body) + + return len(SMALL_PNG), {b"Content-Type": [b"image/png"]}, "", 200 diff --git a/tests/http/__init__.py b/tests/http/__init__.py index e74f7f5b48f1..093537adef52 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. import os.path import subprocess +from typing import List from zope.interface import implementer @@ -70,14 +71,14 @@ def get_test_key_file(): """ -def create_test_cert_file(sanlist): +def create_test_cert_file(sanlist: List[bytes]) -> str: """build an x509 certificate file Args: - sanlist: list[bytes]: a list of subjectAltName values for the cert + sanlist: a list of subjectAltName values for the cert Returns: - str: the path to the file + The path to the file """ global cert_file_count csr_filename = "server.csr" diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 02cef6f876b2..058ca57e559d 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -778,8 +778,11 @@ def _test_sending_local_online_presence_to_local_user( worker process. The test users will still sync with the main process. The purpose of testing with a worker is to check whether a Synapse module running on a worker can inform other workers/ the main process that they should include additional presence when a user next syncs. + If this argument is True, `test_case` MUST be an instance of BaseMultiWorkerStreamTestCase. """ if test_with_workers: + assert isinstance(test_case, BaseMultiWorkerStreamTestCase) + # Create a worker process to make module_api calls against worker_hs = test_case.make_worker_hs( "synapse.app.generic_worker", {"worker_name": "presence_writer"} diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index c824684ecb6f..9852753cb92d 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -62,6 +62,8 @@ def _get_evaluator( power_levels.get("notifications", {}), {} if related_events is None else related_events, True, + event.room_version.msc3931_push_features, + True, ) def test_display_name(self) -> None: diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 121f3d8d6517..3029a16ddad9 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -542,8 +542,13 @@ def handle_command(self, command, *args): self.send("OK") elif command == b"GET": self.send(None) + + # Connection keep-alives. + elif command == b"PING": + self.send("PONG") + else: - raise Exception("Unknown command") + raise Exception(f"Unknown command: {command}") def send(self, msg): """Send a message back to the client.""" diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index bb427abf7b24..becbc88947a9 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -143,6 +143,7 @@ def test_invites(self): self.persist(type="m.room.create", key="", creator=USER_ID) self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") + assert event.internal_metadata.stream_ordering is not None self.replicate() @@ -230,6 +231,7 @@ def test_get_rooms_for_user_with_stream_ordering(self): j2 = self.persist( type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) + assert j2.internal_metadata.stream_ordering is not None self.replicate() expected_pos = PersistedEventPosition( @@ -287,6 +289,7 @@ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): ) ) self.replicate() + assert j2.internal_metadata.stream_ordering is not None event_source = RoomEventSource(self.hs) event_source.store = self.slaved_store @@ -336,10 +339,10 @@ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): event_id = 0 - def persist(self, backfill=False, **kwargs): + def persist(self, backfill=False, **kwargs) -> FrozenEvent: """ Returns: - synapse.events.FrozenEvent: The event that was persisted. + The event that was persisted. """ event, context = self.build_event(**kwargs) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 13aa5eb51aa5..96cdf2c45b16 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -15,8 +15,9 @@ import os from typing import Optional, Tuple +from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.protocol import Factory -from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel from twisted.web.server import Request @@ -102,7 +103,7 @@ def _get_media_req( ) # fish the test server back out of the server-side TLS protocol. - http_server = server_tls_protocol.wrappedProtocol + http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment] # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) @@ -238,16 +239,15 @@ def get_connection_factory(): return test_server_connection_factory -def _build_test_server(connection_creator): +def _build_test_server( + connection_creator: IOpenSSLServerConnectionCreator, +) -> TLSMemoryBIOProtocol: """Construct a test server This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol Args: - connection_creator (IOpenSSLServerConnectionCreator): thing to build - SSL connections - sanlist (list[bytes]): list of the SAN entries for the cert returned - by the server + connection_creator: thing to build SSL connections Returns: TLSMemoryBIOProtocol diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index d52aee8f9282..03f2112b07ce 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -19,6 +19,7 @@ import synapse.rest.admin from synapse.api.errors import Codes +from synapse.handlers.device import DeviceHandler from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index d156be82b04d..e0f5d54abab0 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1857,6 +1857,46 @@ def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: self.assertIn("chunk", channel.json_body) self.assertIn("end", channel.json_body) + def test_room_messages_backward(self) -> None: + """Test room messages can be retrieved by an admin that isn't in the room.""" + latest_event_id = self.helper.send( + self.room_id, body="message 1", tok=self.user_tok + )["event_id"] + + # Check that we get the first and second message when querying /messages. + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?dir=b" % (self.room_id,), + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 6, [event["content"] for event in chunk]) + + # in backwards, this is the first event + self.assertEqual(chunk[0]["event_id"], latest_event_id) + + def test_room_messages_forward(self) -> None: + """Test room messages can be retrieved by an admin that isn't in the room.""" + latest_event_id = self.helper.send( + self.room_id, body="message 1", tok=self.user_tok + )["event_id"] + + # Check that we get the first and second message when querying /messages. + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?dir=f" % (self.room_id,), + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 6, [event["content"] for event in chunk]) + + # in forward, this is the last event + self.assertEqual(chunk[5]["event_id"], latest_event_id) + def test_room_messages_purge(self) -> None: """Test room messages can be retrieved by an admin that isn't in the room.""" store = self.hs.get_datastores().main diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index e3d801f7a88f..b86f341ff5b4 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1108,7 +1108,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None: # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1170,7 +1170,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None: bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7) def test_nested_thread(self) -> None: """ diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index e919e089cb3d..b4daace55617 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -3546,11 +3546,6 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def default_config(self) -> JsonDict: - config = super().default_config() - config["experimental_features"] = {"msc3030_enabled": True} - return config - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self._storage_controllers = self.hs.get_storage_controllers() @@ -3592,7 +3587,7 @@ def test_no_outliers(self) -> None: channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}", + f"/_matrix/client/v1/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}", access_token=self.room_owner_tok, ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 9980158ec66a..71bdfe68cdbb 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -352,14 +353,15 @@ def test_invite_with_notice(self): self.assertTrue(notice_in_room, "No server notice in room") - def _trigger_notice_and_join(self): + def _trigger_notice_and_join(self) -> Tuple[str, str, str]: """Creates enough active users to hit the MAU limit and trigger a system notice about it, then joins the system notices room with one of the users created. Returns: - user_id (str): The ID of the user that joined the room. - tok (str): The access token of the user that joined the room. - room_id (str): The ID of the room that's been joined. + A tuple of: + user_id: The ID of the user that joined the room. + tok: The access token of the user that joined the room. + room_id: The ID of the room that's been joined. """ user_id = None tok = None diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index f37505b6cf6a..8e7db2c4ec39 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -28,7 +28,7 @@ def add_device_change(self, user_id, device_ids, host): """ for device_id in device_ids: - stream_id = self.get_success( + self.get_success( self.store.add_device_change_to_streams( user_id, [device_id], ["!some:room"] ) @@ -39,7 +39,6 @@ def add_device_change(self, user_id, device_ids, host): user_id=user_id, device_id=device_id, room_id="!some:room", - stream_id=stream_id, hosts=[host], context={}, ) diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 2550828af9ad..d0bfae2bdc65 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from prometheus_client import generate_latest -from synapse.metrics import REGISTRY, generate_latest +from synapse.metrics import REGISTRY from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -56,8 +57,8 @@ def test_exposed_to_prometheus(self): items = list( filter( - lambda x: b"synapse_forward_extremities_" in x, - generate_latest(REGISTRY, emit_help=False).split(b"\n"), + lambda x: b"synapse_forward_extremities_" in x and b"# HELP" not in x, + generate_latest(REGISTRY).split(b"\n"), ) ) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 2d8d1f860fd5..d6a2b8d2743e 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -16,15 +16,157 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import IncorrectDatabaseSetup -from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.util import Clock from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS +class StreamIdGeneratorTestCase(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn: LoggingTransaction) -> None: + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + data TEXT + ); + """ + ) + txn.execute("INSERT INTO foobar VALUES (123, 'hello world');") + + def _create_id_generator(self) -> StreamIdGenerator: + def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: + return StreamIdGenerator( + db_conn=conn, + table="foobar", + column="stream_id", + ) + + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def test_initial_value(self) -> None: + """Check that we read the current token from the DB.""" + id_gen = self._create_id_generator() + self.assertEqual(id_gen.get_current_token(), 123) + + def test_single_gen_next(self) -> None: + """Check that we correctly increment the current token from the DB.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + async with id_gen.get_next() as next_id: + # We haven't persisted `next_id` yet; current token is still 123 + self.assertEqual(id_gen.get_current_token(), 123) + # But we did learn what the next value is + self.assertEqual(next_id, 124) + + # Once the context manager closes we assume that the `next_id` has been + # written to the DB. + self.assertEqual(id_gen.get_current_token(), 124) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts(self) -> None: + """Check that we handle overlapping calls to gen_next sensibly.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist each in turn. + await ctx1.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 124) + await ctx2.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 125) + await ctx3.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts_closed_in_different_order(self) -> None: + """Check that we handle overlapping calls to gen_next, even when their IDs + created and persisted in different orders.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist them in a different order, starting with 126 from ctx3. + await ctx3.__aexit__(None, None, None) + # We haven't persisted 124 from ctx1 yet---current token is still 123. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now persist 124 from ctx1. + await ctx1.__aexit__(None, None, None) + # Current token is then 124, waiting for 125 to be persisted. + self.assertEqual(id_gen.get_current_token(), 124) + + # Finally persist 125 from ctx2. + await ctx2.__aexit__(None, None, None) + # Current token is then 126 (skipping over 125). + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_gen_next_while_still_waiting_for_persistence(self) -> None: + """Check that we handle overlapping calls to gen_next.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request two new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + + # Persist ctx2 first. + await ctx2.__aexit__(None, None, None) + # Still waiting on ctx1's ID to be persisted. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now request a third stream ID. It should be 126 (the smallest ID that + # we've not yet handed out.) + self.assertEqual(await ctx3.__aenter__(), 126) + + self.get_success(test_gen_next()) + + class MultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" @@ -48,9 +190,9 @@ def _setup_db(self, txn: LoggingTransaction) -> None: ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -446,7 +588,7 @@ def test_sequence_consistency(self) -> None: self._insert_row_with_id("master", 3) # Now we add a row *without* updating the stream ID - def _insert(txn): + def _insert(txn: Cursor) -> None: txn.execute("INSERT INTO foobar VALUES (26, 'master')") self.get_success(self.db_pool.runInteraction("_insert", _insert)) @@ -481,9 +623,9 @@ def _setup_db(self, txn: LoggingTransaction) -> None: ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -617,9 +759,9 @@ def _setup_db(self, txn: LoggingTransaction) -> None: ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -641,7 +783,7 @@ def _insert_rows( instance_name: str, number: int, update_stream_table: bool = True, - ): + ) -> None: """Insert N rows as the given instance, inserting with stream IDs pulled from the postgres sequence. """ diff --git a/tests/test_visibility.py b/tests/test_visibility.py index db2593c5fbdb..76d7484b2936 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -62,7 +62,7 @@ def test_filtering(self) -> None: filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "test_server", events_to_filter + self._storage_controllers, "test_server", "hs", events_to_filter ) ) @@ -84,7 +84,7 @@ def test_filter_outlier(self) -> None: self.assertEqual( self.get_success( filter_events_for_server( - self._storage_controllers, "remote_hs", [outlier] + self._storage_controllers, "remote_hs", "hs", [outlier] ) ), [outlier], @@ -95,7 +95,7 @@ def test_filter_outlier(self) -> None: filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "remote_hs", [outlier, evt] + self._storage_controllers, "remote_hs", "local_hs", [outlier, evt] ) ) self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") @@ -107,7 +107,7 @@ def test_filter_outlier(self) -> None: # be redacted) filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "other_server", [outlier, evt] + self._storage_controllers, "other_server", "local_hs", [outlier, evt] ) ) self.assertEqual(filtered[0], outlier) @@ -142,7 +142,7 @@ def test_erased_user(self) -> None: # ... and the filtering happens. filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "test_server", events_to_filter + self._storage_controllers, "test_server", "local_hs", events_to_filter ) ) diff --git a/tests/unittest.py b/tests/unittest.py index 5116be338ee0..a120c2976ccd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -360,13 +360,13 @@ def wait_for_background_updates(self) -> None: store.db_pool.updates.do_next_background_update(False), by=0.1 ) - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock): """ Make and return a homeserver. Args: reactor: A Twisted Reactor, or something that pretends to be one. - clock (synapse.util.Clock): The Clock, associated with the reactor. + clock: The Clock, associated with the reactor. Returns: A homeserver suitable for testing. @@ -426,9 +426,8 @@ def prepare( Args: reactor: A Twisted Reactor, or something that pretends to be one. - clock (synapse.util.Clock): The Clock, associated with the reactor. - homeserver (synapse.server.HomeServer): The HomeServer to test - against. + clock: The Clock, associated with the reactor. + homeserver: The HomeServer to test against. Function to optionally be overridden in subclasses. """ @@ -452,11 +451,10 @@ def make_request( given content. Args: - method (bytes/unicode): The HTTP request method ("verb"). - path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. - escaped UTF-8 & spaces and such). - content (bytes or dict): The body of the request. JSON-encoded, if - a dict. + method: The HTTP request method ("verb"). + path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces + and such). content (bytes or dict): The body of the request. + JSON-encoded, if a dict. shorthand: Whether to try and be helpful and prefix the given URL with the usual REST API path, if it doesn't contain it. federation_auth_origin: if set to not-None, we will add a fake diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py index 80b97167bac0..9266f12590cb 100644 --- a/tests/util/caches/test_cached_call.py +++ b/tests/util/caches/test_cached_call.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import NoReturn from unittest.mock import Mock from twisted.internet import defer @@ -23,14 +24,14 @@ class CachedCallTestCase(TestCase): - def test_get(self): + def test_get(self) -> None: """ Happy-path test case: makes a couple of calls and makes sure they behave correctly """ - d = Deferred() + d: "Deferred[int]" = Deferred() - async def f(): + async def f() -> int: return await d slow_call = Mock(side_effect=f) @@ -43,7 +44,7 @@ async def f(): # now fire off a couple of calls completed_results = [] - async def r(): + async def r() -> None: res = await cached_call.get() completed_results.append(res) @@ -69,12 +70,12 @@ async def r(): self.assertEqual(r3, 123) slow_call.assert_not_called() - def test_fast_call(self): + def test_fast_call(self) -> None: """ Test the behaviour when the underlying function completes immediately """ - async def f(): + async def f() -> int: return 12 fast_call = Mock(side_effect=f) @@ -92,12 +93,12 @@ async def f(): class RetryOnExceptionCachedCallTestCase(TestCase): - def test_get(self): + def test_get(self) -> None: # set up the RetryOnExceptionCachedCall around a function which will fail # (after a while) - d = Deferred() + d: "Deferred[int]" = Deferred() - async def f1(): + async def f1() -> NoReturn: await d raise ValueError("moo") @@ -110,7 +111,7 @@ async def f1(): # now fire off a couple of calls completed_results = [] - async def r(): + async def r() -> None: try: await cached_call.get() except Exception as e1: @@ -137,7 +138,7 @@ async def r(): # to the getter d = Deferred() - async def f2(): + async def f2() -> int: return await d slow_call.reset_mock() diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 02b99b466a26..f74d82b1dcd5 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import partial +from typing import List, Tuple from twisted.internet import defer @@ -22,20 +23,20 @@ class DeferredCacheTestCase(TestCase): - def test_empty(self): - cache = DeferredCache("test") + def test_empty(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") with self.assertRaises(KeyError): cache.get("foo") - def test_hit(self): - cache = DeferredCache("test") + def test_hit(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") cache.prefill("foo", 123) self.assertEqual(self.successResultOf(cache.get("foo")), 123) - def test_hit_deferred(self): - cache = DeferredCache("test") - origin_d = defer.Deferred() + def test_hit_deferred(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") + origin_d: "defer.Deferred[int]" = defer.Deferred() set_d = cache.set("k1", origin_d) # get should return an incomplete deferred @@ -43,7 +44,7 @@ def test_hit_deferred(self): self.assertFalse(get_d.called) # add a callback that will make sure that the set_d gets called before the get_d - def check1(r): + def check1(r: str) -> str: self.assertTrue(set_d.called) return r @@ -55,16 +56,16 @@ def check1(r): self.assertEqual(self.successResultOf(set_d), 99) self.assertEqual(self.successResultOf(get_d), 99) - def test_callbacks(self): + def test_callbacks(self) -> None: """Invalidation callbacks are called at the right time""" - cache = DeferredCache("test") + cache: DeferredCache[str, int] = DeferredCache("test") callbacks = set() # start with an entry, with a callback cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) # now replace that entry with a pending result - origin_d = defer.Deferred() + origin_d: "defer.Deferred[int]" = defer.Deferred() set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) # ... and also make a get request @@ -89,15 +90,15 @@ def test_callbacks(self): cache.prefill("k1", 30) self.assertEqual(callbacks, {"set", "get"}) - def test_set_fail(self): - cache = DeferredCache("test") + def test_set_fail(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") callbacks = set() # start with an entry, with a callback cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) # now replace that entry with a pending result - origin_d = defer.Deferred() + origin_d: defer.Deferred = defer.Deferred() set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) # ... and also make a get request @@ -126,9 +127,9 @@ def test_set_fail(self): cache.prefill("k1", 30) self.assertEqual(callbacks, {"prefill", "get2"}) - def test_get_immediate(self): - cache = DeferredCache("test") - d1 = defer.Deferred() + def test_get_immediate(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") + d1: "defer.Deferred[int]" = defer.Deferred() cache.set("key1", d1) # get_immediate should return default @@ -142,27 +143,27 @@ def test_get_immediate(self): v = cache.get_immediate("key1", 1) self.assertEqual(v, 2) - def test_invalidate(self): - cache = DeferredCache("test") + def test_invalidate(self) -> None: + cache: DeferredCache[Tuple[str], int] = DeferredCache("test") cache.prefill(("foo",), 123) cache.invalidate(("foo",)) with self.assertRaises(KeyError): cache.get(("foo",)) - def test_invalidate_all(self): - cache = DeferredCache("testcache") + def test_invalidate_all(self) -> None: + cache: DeferredCache[str, str] = DeferredCache("testcache") callback_record = [False, False] - def record_callback(idx): + def record_callback(idx: int) -> None: callback_record[idx] = True # add a couple of pending entries - d1 = defer.Deferred() + d1: "defer.Deferred[str]" = defer.Deferred() cache.set("key1", d1, partial(record_callback, 0)) - d2 = defer.Deferred() + d2: "defer.Deferred[str]" = defer.Deferred() cache.set("key2", d2, partial(record_callback, 1)) # lookup should return pending deferreds @@ -193,8 +194,8 @@ def record_callback(idx): with self.assertRaises(KeyError): cache.get("key1", None) - def test_eviction(self): - cache = DeferredCache( + def test_eviction(self) -> None: + cache: DeferredCache[int, str] = DeferredCache( "test", max_entries=2, apply_cache_factor_from_config=False ) @@ -208,8 +209,8 @@ def test_eviction(self): cache.get(2) cache.get(3) - def test_eviction_lru(self): - cache = DeferredCache( + def test_eviction_lru(self) -> None: + cache: DeferredCache[int, str] = DeferredCache( "test", max_entries=2, apply_cache_factor_from_config=False ) @@ -227,8 +228,8 @@ def test_eviction_lru(self): cache.get(1) cache.get(3) - def test_eviction_iterable(self): - cache = DeferredCache( + def test_eviction_iterable(self) -> None: + cache: DeferredCache[int, List[str]] = DeferredCache( "test", max_entries=3, apply_cache_factor_from_config=False, diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 43475a307f9b..13f1edd5332c 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Iterable, Set, Tuple +from typing import Iterable, Set, Tuple, cast from unittest import mock from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.interfaces import IReactorTime from synapse.api.errors import SynapseError from synapse.logging.context import ( @@ -37,8 +38,8 @@ def run_on_reactor(): - d = defer.Deferred() - reactor.callLater(0, d.callback, 0) + d: "Deferred[int]" = defer.Deferred() + cast(IReactorTime, reactor).callLater(0, d.callback, 0) return make_deferred_yieldable(d) @@ -224,7 +225,8 @@ def fn(self, arg1): callbacks: Set[str] = set() # set off an asynchronous request - obj.result = origin_d = defer.Deferred() + origin_d: Deferred = defer.Deferred() + obj.result = origin_d d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) self.assertFalse(d1.called) @@ -262,7 +264,7 @@ def test_cache_logcontexts(self): """Check that logcontexts are set and restored correctly when using the cache.""" - complete_lookup = defer.Deferred() + complete_lookup: Deferred = defer.Deferred() class Cls: @descriptors.cached() @@ -772,10 +774,14 @@ def fn(self, arg1, arg2): @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): - assert current_context().name == "c1" + context = current_context() + assert isinstance(context, LoggingContext) + assert context.name == "c1" # we want this to behave like an asynchronous function await run_on_reactor() - assert current_context().name == "c1" + context = current_context() + assert isinstance(context, LoggingContext) + assert context.name == "c1" return self.mock(args1, arg2) with LoggingContext("c1") as c1: @@ -834,7 +840,7 @@ def list_fn(self, args1) -> "Deferred[dict]": return self.mock(args1) obj = Cls() - deferred_result = Deferred() + deferred_result: "Deferred[dict]" = Deferred() obj.mock.return_value = deferred_result # start off several concurrent lookups of the same key diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py index 025b73e32f90..f09eeecadabe 100644 --- a/tests/util/caches/test_response_cache.py +++ b/tests/util/caches/test_response_cache.py @@ -35,7 +35,7 @@ class ResponseCacheTestCase(TestCase): (These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock) """ - def setUp(self): + def setUp(self) -> None: self.reactor, self.clock = get_clock() def with_cache(self, name: str, ms: int = 0) -> ResponseCache: @@ -49,7 +49,7 @@ async def delayed_return(self, o: str) -> str: await self.clock.sleep(1) return o - def test_cache_hit(self): + def test_cache_hit(self) -> None: cache = self.with_cache("keeping_cache", ms=9001) expected_result = "howdy" @@ -74,7 +74,7 @@ def test_cache_hit(self): "cache should still have the result", ) - def test_cache_miss(self): + def test_cache_miss(self) -> None: cache = self.with_cache("trashing_cache", ms=0) expected_result = "howdy" @@ -90,7 +90,7 @@ def test_cache_miss(self): ) self.assertCountEqual([], cache.keys(), "cache should not have the result now") - def test_cache_expire(self): + def test_cache_expire(self) -> None: cache = self.with_cache("short_cache", ms=1000) expected_result = "howdy" @@ -115,7 +115,7 @@ def test_cache_expire(self): self.reactor.pump((2,)) self.assertCountEqual([], cache.keys(), "cache should not have the result now") - def test_cache_wait_hit(self): + def test_cache_wait_hit(self) -> None: cache = self.with_cache("neutral_cache") expected_result = "howdy" @@ -131,7 +131,7 @@ def test_cache_wait_hit(self): self.assertEqual(expected_result, self.successResultOf(wrap_d)) - def test_cache_wait_expire(self): + def test_cache_wait_expire(self) -> None: cache = self.with_cache("medium_cache", ms=3000) expected_result = "howdy" @@ -162,7 +162,7 @@ def test_cache_wait_expire(self): self.assertCountEqual([], cache.keys(), "cache should not have the result now") @parameterized.expand([(True,), (False,)]) - def test_cache_context_nocache(self, should_cache: bool): + def test_cache_context_nocache(self, should_cache: bool) -> None: """If the callback clears the should_cache bit, the result should not be cached""" cache = self.with_cache("medium_cache", ms=3000) @@ -170,7 +170,7 @@ def test_cache_context_nocache(self, should_cache: bool): call_count = 0 - async def non_caching(o: str, cache_context: ResponseCacheContext[int]): + async def non_caching(o: str, cache_context: ResponseCacheContext[int]) -> str: nonlocal call_count call_count += 1 await self.clock.sleep(1) diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py index fe8314057da1..679d1eb36bd9 100644 --- a/tests/util/caches/test_ttlcache.py +++ b/tests/util/caches/test_ttlcache.py @@ -20,11 +20,11 @@ class CacheTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.mock_timer = Mock(side_effect=lambda: 100.0) - self.cache = TTLCache("test_cache", self.mock_timer) + self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer) - def test_get(self): + def test_get(self) -> None: """simple set/get tests""" self.cache.set("one", "1", 10) self.cache.set("two", "2", 20) @@ -59,7 +59,7 @@ def test_get(self): self.assertEqual(self.cache._metrics.hits, 4) self.assertEqual(self.cache._metrics.misses, 5) - def test_expiry(self): + def test_expiry(self) -> None: self.cache.set("one", "1", 10) self.cache.set("two", "2", 20) self.cache.set("three", "3", 30)