diff --git a/.buildkite/.env b/.buildkite/.env deleted file mode 100644 index 85b102d07fff..000000000000 --- a/.buildkite/.env +++ /dev/null @@ -1,13 +0,0 @@ -CI -BUILDKITE -BUILDKITE_BUILD_NUMBER -BUILDKITE_BRANCH -BUILDKITE_BUILD_NUMBER -BUILDKITE_JOB_ID -BUILDKITE_BUILD_URL -BUILDKITE_PROJECT_SLUG -BUILDKITE_COMMIT -BUILDKITE_PULL_REQUEST -BUILDKITE_TAG -CODECOV_TOKEN -TRIAL_FLAGS diff --git a/.buildkite/merge_base_branch.sh b/.buildkite/merge_base_branch.sh deleted file mode 100755 index 361440fd1a1c..000000000000 --- a/.buildkite/merge_base_branch.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env bash - -set -e - -if [[ "$BUILDKITE_BRANCH" =~ ^(develop|master|dinsic|shhs|release-.*)$ ]]; then - echo "Not merging forward, as this is a release branch" - exit 0 -fi - -if [[ -z $BUILDKITE_PULL_REQUEST_BASE_BRANCH ]]; then - echo "Not a pull request, or hasn't had a PR opened yet..." - - # It probably hasn't had a PR opened yet. Since all PRs land on develop, we - # can probably assume it's based on it and will be merged into it. - GITBASE="develop" -else - # Get the reference, using the GitHub API - GITBASE=$BUILDKITE_PULL_REQUEST_BASE_BRANCH -fi - -echo "--- merge_base_branch $GITBASE" - -# Show what we are before -git --no-pager show -s - -# Set up username so it can do a merge -git config --global user.email bot@matrix.org -git config --global user.name "A robot" - -# Fetch and merge. If it doesn't work, it will raise due to set -e. -git fetch -u origin $GITBASE -git merge --no-edit --no-commit origin/$GITBASE - -# Show what we are after. -git --no-pager show -s diff --git a/.buildkite/postgres-config.yaml b/.ci/postgres-config.yaml similarity index 86% rename from .buildkite/postgres-config.yaml rename to .ci/postgres-config.yaml index 67e17fa9d1de..f5a4aecd51ca 100644 --- a/.buildkite/postgres-config.yaml +++ b/.ci/postgres-config.yaml @@ -3,7 +3,7 @@ # CI's Docker setup at the point where this file is considered. server_name: "localhost:8800" -signing_key_path: ".buildkite/test.signing.key" +signing_key_path: ".ci/test.signing.key" report_stats: false @@ -11,7 +11,7 @@ database: name: "psycopg2" args: user: postgres - host: postgres + host: localhost password: postgres database: synapse diff --git a/.buildkite/scripts/postgres_exec.py b/.ci/scripts/postgres_exec.py similarity index 92% rename from .buildkite/scripts/postgres_exec.py rename to .ci/scripts/postgres_exec.py index 086b391724ec..0f39a336d52d 100755 --- a/.buildkite/scripts/postgres_exec.py +++ b/.ci/scripts/postgres_exec.py @@ -23,7 +23,7 @@ # We use "postgres" as a database because it's bound to exist and the "synapse" one # doesn't exist yet. db_conn = psycopg2.connect( - user="postgres", host="postgres", password="postgres", dbname="postgres" + user="postgres", host="localhost", password="postgres", dbname="postgres" ) db_conn.autocommit = True cur = db_conn.cursor() diff --git a/.buildkite/scripts/test_old_deps.sh b/.ci/scripts/test_old_deps.sh similarity index 81% rename from .buildkite/scripts/test_old_deps.sh rename to .ci/scripts/test_old_deps.sh index 9270d55f0461..8b473936f8c3 100755 --- a/.buildkite/scripts/test_old_deps.sh +++ b/.ci/scripts/test_old_deps.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -# this script is run by buildkite in a plain `bionic` container; it installs the +# this script is run by GitHub Actions in a plain `bionic` container; it installs the # minimal requirements for tox and hands over to the py3-old tox environment. 
set -ex diff --git a/.buildkite/scripts/test_synapse_port_db.sh b/.ci/scripts/test_synapse_port_db.sh similarity index 60% rename from .buildkite/scripts/test_synapse_port_db.sh rename to .ci/scripts/test_synapse_port_db.sh index 82d7d56d4e9f..2b4e5ec1707d 100755 --- a/.buildkite/scripts/test_synapse_port_db.sh +++ b/.ci/scripts/test_synapse_port_db.sh @@ -20,22 +20,22 @@ pip install -e . echo "--- Generate the signing key" # Generate the server's signing key. -python -m synapse.app.homeserver --generate-keys -c .buildkite/sqlite-config.yaml +python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml echo "--- Prepare test database" # Make sure the SQLite3 database is using the latest schema and has no pending background update. -scripts-dev/update_database --database-config .buildkite/sqlite-config.yaml +scripts-dev/update_database --database-config .ci/sqlite-config.yaml # Create the PostgreSQL database. -./.buildkite/scripts/postgres_exec.py "CREATE DATABASE synapse" +.ci/scripts/postgres_exec.py "CREATE DATABASE synapse" echo "+++ Run synapse_port_db against test database" -coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml +coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml # We should be able to run twice against the same database. echo "+++ Run synapse_port_db a second time" -coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml +coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml ##### @@ -44,14 +44,14 @@ coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --p echo "--- Prepare empty SQLite database" # we do this by deleting the sqlite db, and then doing the same again. -rm .buildkite/test_db.db +rm .ci/test_db.db -scripts-dev/update_database --database-config .buildkite/sqlite-config.yaml +scripts-dev/update_database --database-config .ci/sqlite-config.yaml # re-create the PostgreSQL database. -./.buildkite/scripts/postgres_exec.py \ +.ci/scripts/postgres_exec.py \ "DROP DATABASE synapse" \ "CREATE DATABASE synapse" echo "+++ Run synapse_port_db against empty database" -coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml +coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml diff --git a/.buildkite/sqlite-config.yaml b/.ci/sqlite-config.yaml similarity index 80% rename from .buildkite/sqlite-config.yaml rename to .ci/sqlite-config.yaml index d16459cfd947..3373743da3cd 100644 --- a/.buildkite/sqlite-config.yaml +++ b/.ci/sqlite-config.yaml @@ -3,14 +3,14 @@ # schema and run background updates on it. server_name: "localhost:8800" -signing_key_path: ".buildkite/test.signing.key" +signing_key_path: ".ci/test.signing.key" report_stats: false database: name: "sqlite3" args: - database: ".buildkite/test_db.db" + database: ".ci/test_db.db" # Suppress the key server warning. 
trusted_key_servers: [] diff --git a/.buildkite/test_db.db b/.ci/test_db.db similarity index 100% rename from .buildkite/test_db.db rename to .ci/test_db.db diff --git a/.buildkite/worker-blacklist b/.ci/worker-blacklist similarity index 100% rename from .buildkite/worker-blacklist rename to .ci/worker-blacklist diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 239553ae138e..8736699ad8cf 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,7 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true - + jobs: lint: runs-on: ubuntu-latest @@ -38,20 +38,15 @@ jobs: if: ${{ github.base_ref == 'develop' || contains(github.base_ref, 'release-') }} runs-on: ubuntu-latest steps: - # Note: This and the script can be simplified once we drop Buildkite. See: - # https://github.com/actions/checkout/issues/266#issuecomment-638346893 - # https://github.com/actions/checkout/issues/416 - uses: actions/checkout@v2 with: ref: ${{ github.event.pull_request.head.sha }} fetch-depth: 0 - uses: actions/setup-python@v2 - run: pip install tox - - name: Patch Buildkite-specific test script - run: | - sed -i -e 's/\$BUILDKITE_PULL_REQUEST/${{ github.event.number }}/' \ - scripts-dev/check-newsfragment - run: scripts-dev/check-newsfragment + env: + PULL_REQUEST_NUMBER: ${{ github.event.number }} lint-sdist: runs-on: ubuntu-latest @@ -144,7 +139,7 @@ jobs: uses: docker://ubuntu:bionic # For old python and sqlite with: workdir: /github/workspace - entrypoint: .buildkite/scripts/test_old_deps.sh + entrypoint: .ci/scripts/test_old_deps.sh env: TRIAL_FLAGS: "--jobs=2" - name: Dump logs @@ -197,12 +192,12 @@ jobs: volumes: - ${{ github.workspace }}:/src env: - BUILDKITE_BRANCH: ${{ github.head_ref }} POSTGRES: ${{ matrix.postgres && 1}} MULTI_POSTGRES: ${{ (matrix.postgres == 'multi-postgres') && 1}} WORKERS: ${{ matrix.workers && 1 }} REDIS: ${{ matrix.redis && 1 }} BLACKLIST: ${{ matrix.workers && 'synapse-blacklist-with-workers' }} + TOP: ${{ github.workspace }} strategy: fail-fast: false @@ -232,7 +227,7 @@ jobs: steps: - uses: actions/checkout@v2 - name: Prepare test blacklist - run: cat sytest-blacklist .buildkite/worker-blacklist > synapse-blacklist-with-workers + run: cat sytest-blacklist .ci/worker-blacklist > synapse-blacklist-with-workers - name: Run SyTest run: /bootstrap.sh synapse working-directory: /src @@ -252,6 +247,8 @@ jobs: if: ${{ !failure() && !cancelled() }} # Allow previous steps to be skipped, but not fail needs: linting-done runs-on: ubuntu-latest + env: + TOP: ${{ github.workspace }} strategy: matrix: include: @@ -281,13 +278,7 @@ jobs: - uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Patch Buildkite-specific test scripts - run: | - sed -i -e 's/host="postgres"/host="localhost"/' .buildkite/scripts/postgres_exec.py - sed -i -e 's/host: postgres/host: localhost/' .buildkite/postgres-config.yaml - sed -i -e 's|/src/||' .buildkite/{sqlite,postgres}-config.yaml - sed -i -e 's/\$TOP/\$GITHUB_WORKSPACE/' .coveragerc - - run: .buildkite/scripts/test_synapse_port_db.sh + - run: .ci/scripts/test_synapse_port_db.sh complement: if: ${{ !failure() && !cancelled() }} @@ -374,6 +365,11 @@ jobs: rc=0 results=$(jq -r 'to_entries[] | [.key,.value.result] | join(" ")' <<< $NEEDS_CONTEXT) while read job result ; do + # The newsfile lint may be skipped on non PR builds + if [ $result == "skipped" ] && [ $job == "lint-newsfile" ]; then + continue + fi + if [ "$result" != "success" 
]; then echo "::set-failed ::Job $job returned $result" rc=1 diff --git a/CHANGES.md b/CHANGES.md index 0e5e052951a9..f8da8771aa6e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,96 @@ +Synapse 1.41.0 (2021-08-24) +=========================== + +This release adds support for Debian 12 (Bookworm), but **removes support for Ubuntu 20.10 (Groovy Gorilla)**, which reached End of Life last month. + +Note that when using workers the `/_synapse/admin/v1/users/{userId}/media` must now be handled by media workers. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. + + +Features +-------- + +- Enable room capabilities ([MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244)) by default and set room version 8 as the preferred room version when creating restricted rooms. ([\#10571](https://github.com/matrix-org/synapse/issues/10571)) + + +Synapse 1.41.0rc1 (2021-08-18) +============================== + +Features +-------- + +- Add `get_userinfo_by_id` method to ModuleApi. ([\#9581](https://github.com/matrix-org/synapse/issues/9581)) +- Initial local support for [MSC3266](https://github.com/matrix-org/synapse/pull/10394), Room Summary over the unstable `/rooms/{roomIdOrAlias}/summary` API. ([\#10394](https://github.com/matrix-org/synapse/issues/10394)) +- Experimental support for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288), sending `room_type` to the identity server for 3pid invites over the `/store-invite` API. ([\#10435](https://github.com/matrix-org/synapse/issues/10435)) +- Add support for sending federation requests through a proxy. Contributed by @Bubu and @dklimpel. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. ([\#10596](https://github.com/matrix-org/synapse/issues/10596)). ([\#10475](https://github.com/matrix-org/synapse/issues/10475)) +- Add support for "marker" events which makes historical events discoverable for servers that already have all of the scrollback history (part of [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716)). ([\#10498](https://github.com/matrix-org/synapse/issues/10498)) +- Add a configuration setting for the time a `/sync` response is cached for. ([\#10513](https://github.com/matrix-org/synapse/issues/10513)) +- The default logging handler for new installations is now `PeriodicallyFlushingMemoryHandler`, a buffered logging handler which periodically flushes itself. ([\#10518](https://github.com/matrix-org/synapse/issues/10518)) +- Add support for new redaction rules for historical events specified in [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). ([\#10538](https://github.com/matrix-org/synapse/issues/10538)) +- Add a setting to disable TLS when sending email. ([\#10546](https://github.com/matrix-org/synapse/issues/10546)) +- Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10549](https://github.com/matrix-org/synapse/issues/10549), [\#10560](https://github.com/matrix-org/synapse/issues/10560), [\#10569](https://github.com/matrix-org/synapse/issues/10569), [\#10574](https://github.com/matrix-org/synapse/issues/10574), [\#10575](https://github.com/matrix-org/synapse/issues/10575), [\#10579](https://github.com/matrix-org/synapse/issues/10579), [\#10583](https://github.com/matrix-org/synapse/issues/10583)) +- Admin API to delete several media for a specific user. Contributed by @dklimpel. 
([\#10558](https://github.com/matrix-org/synapse/issues/10558), [\#10628](https://github.com/matrix-org/synapse/issues/10628)) +- Add support for routing `/createRoom` to workers. ([\#10564](https://github.com/matrix-org/synapse/issues/10564)) +- Update the Synapse Grafana dashboard. ([\#10570](https://github.com/matrix-org/synapse/issues/10570)) +- Add an admin API (`GET /_synapse/admin/username_available`) to check if a username is available (regardless of registration settings). ([\#10578](https://github.com/matrix-org/synapse/issues/10578)) +- Allow editing a user's `external_ids` via the "Edit User" admin API. Contributed by @dklimpel. ([\#10598](https://github.com/matrix-org/synapse/issues/10598)) +- The Synapse manhole no longer needs coroutines to be wrapped in `defer.ensureDeferred`. ([\#10602](https://github.com/matrix-org/synapse/issues/10602)) +- Add option to allow modules to run periodic tasks on all instances, rather than just the one configured to run background tasks. ([\#10638](https://github.com/matrix-org/synapse/issues/10638)) + + +Bugfixes +-------- + +- Add some clarification to the sample config file. Contributed by @Kentokamoto. ([\#10129](https://github.com/matrix-org/synapse/issues/10129)) +- Fix a long-standing bug where protocols which are not implemented by any appservices were incorrectly returned via `GET /_matrix/client/r0/thirdparty/protocols`. ([\#10532](https://github.com/matrix-org/synapse/issues/10532)) +- Fix exceptions in logs when failing to get remote room list. ([\#10541](https://github.com/matrix-org/synapse/issues/10541)) +- Fix longstanding bug which caused the user's presence "status message" to be reset when the user went offline. Contributed by @dklimpel. ([\#10550](https://github.com/matrix-org/synapse/issues/10550)) +- Allow public rooms to be previewed in the spaces summary APIs from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10580](https://github.com/matrix-org/synapse/issues/10580)) +- Fix a bug introduced in v1.37.1 where an error could occur in the asynchronous processing of PDUs when the queue was empty. ([\#10592](https://github.com/matrix-org/synapse/issues/10592)) +- Fix errors on /sync when read receipt data is a string. Only affects homeservers with the experimental flag for [MSC2285](https://github.com/matrix-org/matrix-doc/pull/2285) enabled. Contributed by @SimonBrandner. ([\#10606](https://github.com/matrix-org/synapse/issues/10606)) +- Additional validation for the spaces summary API to avoid errors like `ValueError: Stop argument for islice() must be None or an integer`. The missing validation has existed since v1.31.0. ([\#10611](https://github.com/matrix-org/synapse/issues/10611)) +- Revert behaviour introduced in v1.38.0 that strips `org.matrix.msc2732.device_unused_fallback_key_types` from `/sync` when its value is empty. This field should instead always be present according to [MSC2732](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2732-olm-fallback-keys.md). ([\#10623](https://github.com/matrix-org/synapse/issues/10623)) + + +Improved Documentation +---------------------- + +- Add documentation for configuring a forward proxy. ([\#10443](https://github.com/matrix-org/synapse/issues/10443)) +- Updated the reverse proxy documentation to highlight the homeserver configuration that is needed to make Synapse aware that it is intentionally reverse proxied.
([\#10551](https://github.com/matrix-org/synapse/issues/10551)) +- Update CONTRIBUTING.md to fix index links and the instructions for SyTest in docker. ([\#10599](https://github.com/matrix-org/synapse/issues/10599)) + + +Deprecations and Removals +------------------------- + +- No longer build `.deb` packages for Ubuntu 20.10 Groovy Gorilla, which has now EOLed. ([\#10588](https://github.com/matrix-org/synapse/issues/10588)) +- The `template_dir` configuration settings in the `sso`, `account_validity` and `email` sections of the configuration file are now deprecated in favour of the global `templates.custom_template_directory` setting. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. ([\#10596](https://github.com/matrix-org/synapse/issues/10596)) + + +Internal Changes +---------------- + +- Improve event caching mechanism to avoid having multiple copies of an event in memory at a time. ([\#10119](https://github.com/matrix-org/synapse/issues/10119)) +- Reduce errors in PostgreSQL logs due to concurrent serialization errors. ([\#10504](https://github.com/matrix-org/synapse/issues/10504)) +- Include room ID in ignored EDU log messages. Contributed by @ilmari. ([\#10507](https://github.com/matrix-org/synapse/issues/10507)) +- Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10527](https://github.com/matrix-org/synapse/issues/10527), [\#10530](https://github.com/matrix-org/synapse/issues/10530)) +- Fix CI to not break when run against branches rather than pull requests. ([\#10529](https://github.com/matrix-org/synapse/issues/10529)) +- Mark all events stemming from the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint as historical. ([\#10537](https://github.com/matrix-org/synapse/issues/10537)) +- Clean up some of the federation event authentication code for clarity. ([\#10539](https://github.com/matrix-org/synapse/issues/10539), [\#10591](https://github.com/matrix-org/synapse/issues/10591)) +- Convert `Transaction` and `Edu` objects to attrs. ([\#10542](https://github.com/matrix-org/synapse/issues/10542)) +- Update `/batch_send` endpoint to only return `state_events` created by the `state_events_from_before` passed in. ([\#10552](https://github.com/matrix-org/synapse/issues/10552)) +- Update contributing.md to warn against rebasing an open PR. ([\#10563](https://github.com/matrix-org/synapse/issues/10563)) +- Remove the unused public rooms replication stream. ([\#10565](https://github.com/matrix-org/synapse/issues/10565)) +- Clarify error message when failing to join a restricted room. ([\#10572](https://github.com/matrix-org/synapse/issues/10572)) +- Remove references to BuildKite in favour of GitHub Actions. ([\#10573](https://github.com/matrix-org/synapse/issues/10573)) +- Move `/batch_send` endpoint defined by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) to the `/v2_alpha` directory. ([\#10576](https://github.com/matrix-org/synapse/issues/10576)) +- Allow multiple custom directories in `read_templates`. ([\#10587](https://github.com/matrix-org/synapse/issues/10587)) +- Re-organize the `synapse.federation.transport.server` module to create smaller files. ([\#10590](https://github.com/matrix-org/synapse/issues/10590)) +- Flatten the `synapse.rest.client` package by moving the contents of `v1` and `v2_alpha` into the parent. 
([\#10600](https://github.com/matrix-org/synapse/issues/10600)) +- Build Debian packages for Debian 12 (Bookworm). ([\#10612](https://github.com/matrix-org/synapse/issues/10612)) +- Fix up a couple of links to the database schema documentation. ([\#10620](https://github.com/matrix-org/synapse/issues/10620)) +- Fix a broken link to the upgrade notes. ([\#10631](https://github.com/matrix-org/synapse/issues/10631)) + + Synapse 1.40.0 (2021-08-10) =========================== diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e7eef23419d5..cd6c34df85b1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,8 +13,9 @@ This document aims to get you started with contributing to this repo! - [7. Turn coffee and documentation into code and documentation!](#7-turn-coffee-and-documentation-into-code-and-documentation) - [8. Test, test, test!](#8-test-test-test) * [Run the linters.](#run-the-linters) - * [Run the unit tests.](#run-the-unit-tests) - * [Run the integration tests.](#run-the-integration-tests) + * [Run the unit tests.](#run-the-unit-tests-twisted-trial) + * [Run the integration tests (SyTest).](#run-the-integration-tests-sytest) + * [Run the integration tests (Complement).](#run-the-integration-tests-complement) - [9. Submit your patch.](#9-submit-your-patch) * [Changelog](#changelog) + [How do I know what to call the changelog file before I create the PR?](#how-do-i-know-what-to-call-the-changelog-file-before-i-create-the-pr) @@ -197,7 +198,7 @@ The following command will let you run the integration test with the most common configuration: ```sh -$ docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v /path/to/where/you/want/logs\:/logs matrixdotorg/sytest-synapse:py37 +$ docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v /path/to/where/you/want/logs\:/logs matrixdotorg/sytest-synapse:buster ``` This configuration should generally cover your needs. For more details about other configurations, see [documentation in the SyTest repo](https://github.com/matrix-org/sytest/blob/develop/docker/README.md). @@ -252,6 +253,7 @@ To prepare a Pull Request, please: 4. on GitHub, [create the Pull Request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request); 5. add a [changelog entry](#changelog) and push it to your Pull Request; 6. for most contributors, that's all - however, if you are a member of the organization `matrix-org`, on GitHub, please request a review from `matrix.org / Synapse Core`. +7. if you need to update your PR, please avoid rebasing and just add new commits to your branch. ## Changelog diff --git a/MANIFEST.in b/MANIFEST.in index 0522319c4000..44d5cc761816 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -44,9 +44,9 @@ include book.toml include pyproject.toml recursive-include changelog.d * -prune .buildkite prune .circleci prune .github +prune .ci prune contrib prune debian prune demo/etc diff --git a/UPGRADE.rst b/UPGRADE.rst index 17ecd935fdbb..6c7f9cb18e9f 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -1,7 +1,7 @@ Upgrading Synapse ================= -This document has moved to the `Synapse documentation website `_. +This document has moved to the `Synapse documentation website <https://matrix-org.github.io/synapse/latest/upgrade.html>`_. Please update your links. The markdown source is available in `docs/upgrade.md <docs/upgrade.md>`_.
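The CONTRIBUTING.md hunk above pins the SyTest image to `buster`. The same container can exercise other configurations through environment variables, much as the GitHub Actions workflow earlier in this diff does with `POSTGRES`, `WORKERS` and `REDIS`; a minimal sketch of a PostgreSQL run, assuming the `POSTGRES=1` toggle documented in the SyTest docker README (the mount paths are placeholders):

```sh
# Same integration-test invocation as in CONTRIBUTING.md, but asking the
# container to provision a PostgreSQL database for the tests instead of
# the default SQLite one.
$ docker run --rm -it \
    -v /path/where/you/have/cloned/the/repository:/src:ro \
    -v /path/to/where/you/want/logs:/logs \
    -e POSTGRES=1 \
    matrixdotorg/sytest-synapse:buster
```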
diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json index 0c4816b7cd53..ed1e8ba7f8b6 100644 --- a/contrib/grafana/synapse.json +++ b/contrib/grafana/synapse.json @@ -54,7 +54,7 @@ "gnetId": null, "graphTooltip": 0, "id": null, - "iteration": 1621258266004, + "iteration": 1628606819564, "links": [ { "asDropdown": false, @@ -307,7 +307,6 @@ ], "thresholds": [ { - "$$hashKey": "object:283", "colorMode": "warning", "fill": false, "line": true, @@ -316,7 +315,6 @@ "yaxis": "left" }, { - "$$hashKey": "object:284", "colorMode": "critical", "fill": false, "line": true, @@ -344,7 +342,6 @@ }, "yaxes": [ { - "$$hashKey": "object:255", "decimals": null, "format": "s", "label": "", @@ -354,7 +351,6 @@ "show": true }, { - "$$hashKey": "object:256", "format": "hertz", "label": "", "logBase": 1, @@ -429,7 +425,6 @@ ], "thresholds": [ { - "$$hashKey": "object:566", "colorMode": "critical", "fill": true, "line": true, @@ -457,7 +452,6 @@ }, "yaxes": [ { - "$$hashKey": "object:538", "decimals": null, "format": "percentunit", "label": null, @@ -467,7 +461,6 @@ "show": true }, { - "$$hashKey": "object:539", "format": "short", "label": null, "logBase": 1, @@ -573,7 +566,6 @@ }, "yaxes": [ { - "$$hashKey": "object:1560", "format": "bytes", "logBase": 1, "max": null, @@ -581,7 +573,6 @@ "show": true }, { - "$$hashKey": "object:1561", "format": "short", "logBase": 1, "max": null, @@ -641,7 +632,6 @@ "renderer": "flot", "seriesOverrides": [ { - "$$hashKey": "object:639", "alias": "/max$/", "color": "#890F02", "fill": 0, @@ -693,7 +683,6 @@ }, "yaxes": [ { - "$$hashKey": "object:650", "decimals": null, "format": "none", "label": "", @@ -703,7 +692,6 @@ "show": true }, { - "$$hashKey": "object:651", "decimals": null, "format": "short", "label": null, @@ -783,11 +771,9 @@ "renderer": "flot", "seriesOverrides": [ { - "$$hashKey": "object:1240", "alias": "/user/" }, { - "$$hashKey": "object:1241", "alias": "/system/" } ], @@ -817,7 +803,6 @@ ], "thresholds": [ { - "$$hashKey": "object:1278", "colorMode": "custom", "fillColor": "rgba(255, 255, 255, 1)", "line": true, @@ -827,7 +812,6 @@ "yaxis": "left" }, { - "$$hashKey": "object:1279", "colorMode": "custom", "fillColor": "rgba(255, 255, 255, 1)", "line": true, @@ -837,7 +821,6 @@ "yaxis": "left" }, { - "$$hashKey": "object:1498", "colorMode": "critical", "fill": true, "line": true, @@ -865,7 +848,6 @@ }, "yaxes": [ { - "$$hashKey": "object:1250", "decimals": null, "format": "percentunit", "label": "", @@ -875,7 +857,6 @@ "show": true }, { - "$$hashKey": "object:1251", "format": "short", "logBase": 1, "max": null, @@ -1427,7 +1408,6 @@ }, "yaxes": [ { - "$$hashKey": "object:572", "format": "percentunit", "label": null, "logBase": 1, @@ -1436,7 +1416,6 @@ "show": true }, { - "$$hashKey": "object:573", "format": "short", "label": null, "logBase": 1, @@ -1720,7 +1699,6 @@ }, "yaxes": [ { - "$$hashKey": "object:102", "format": "hertz", "logBase": 1, "max": null, @@ -1728,7 +1706,6 @@ "show": true }, { - "$$hashKey": "object:103", "format": "short", "logBase": 1, "max": null, @@ -3425,7 +3402,7 @@ "h": 9, "w": 12, "x": 0, - "y": 33 + "y": 6 }, "hiddenSeries": false, "id": 79, @@ -3442,9 +3419,12 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -3526,7 +3506,7 @@ "h": 9, "w": 12, "x": 12, - "y": 33 + "y": 6 }, "hiddenSeries": false, "id": 83, @@ -3543,9 
+3523,12 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -3629,7 +3612,7 @@ "h": 9, "w": 12, "x": 0, - "y": 42 + "y": 15 }, "hiddenSeries": false, "id": 109, @@ -3646,9 +3629,12 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -3733,7 +3719,7 @@ "h": 9, "w": 12, "x": 12, - "y": 42 + "y": 15 }, "hiddenSeries": false, "id": 111, @@ -3750,9 +3736,12 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -3831,7 +3820,7 @@ "h": 8, "w": 12, "x": 0, - "y": 51 + "y": 24 }, "hiddenSeries": false, "id": 142, @@ -3847,8 +3836,11 @@ "lines": true, "linewidth": 1, "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 2, "points": false, "renderer": "flot", @@ -3931,7 +3923,7 @@ "h": 9, "w": 12, "x": 12, - "y": 51 + "y": 24 }, "hiddenSeries": false, "id": 140, @@ -3948,9 +3940,12 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -4079,7 +4074,7 @@ "h": 9, "w": 12, "x": 0, - "y": 59 + "y": 32 }, "heatmap": {}, "hideZeroBuckets": false, @@ -4145,7 +4140,7 @@ "h": 9, "w": 12, "x": 12, - "y": 60 + "y": 33 }, "hiddenSeries": false, "id": 162, @@ -4163,9 +4158,12 @@ "linewidth": 0, "links": [], "nullPointMode": "connected", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -4350,7 +4348,7 @@ "h": 9, "w": 12, "x": 0, - "y": 68 + "y": 41 }, "heatmap": {}, "hideZeroBuckets": false, @@ -4396,6 +4394,311 @@ "yBucketBound": "auto", "yBucketNumber": null, "yBucketSize": null + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "$datasource", + "editable": true, + "error": false, + "fieldConfig": { + "defaults": { + "custom": {}, + "links": [] + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "grid": {}, + "gridPos": { + "h": 9, + "w": 12, + "x": 12, + "y": 42 + }, + "hiddenSeries": false, + "id": 203, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 2, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "paceLength": 10, + "percentage": false, + "pluginVersion": "7.3.7", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "synapse_federation_server_oldest_inbound_pdu_in_staging{job=\"$job\",index=~\"$index\",instance=\"$instance\"}", + "format": "time_series", + "interval": "", + "intervalFactor": 1, + "legendFormat": 
"rss {{index}}", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Age of oldest event in staging area", + "tooltip": { + "msResolution": false, + "shared": true, + "sort": 0, + "value_type": "cumulative" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "ms", + "label": null, + "logBase": 1, + "max": null, + "min": 0, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "$datasource", + "editable": true, + "error": false, + "fieldConfig": { + "defaults": { + "custom": {}, + "links": [] + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "grid": {}, + "gridPos": { + "h": 9, + "w": 12, + "x": 0, + "y": 50 + }, + "hiddenSeries": false, + "id": 202, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 2, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "paceLength": 10, + "percentage": false, + "pluginVersion": "7.3.7", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "synapse_federation_server_number_inbound_pdu_in_staging{job=\"$job\",index=~\"$index\",instance=\"$instance\"}", + "format": "time_series", + "interval": "", + "intervalFactor": 1, + "legendFormat": "rss {{index}}", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Number of events in federation staging area", + "tooltip": { + "msResolution": false, + "shared": true, + "sort": 0, + "value_type": "cumulative" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "none", + "label": null, + "logBase": 1, + "max": null, + "min": 0, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 51 + }, + "hiddenSeries": false, + "id": 205, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.3.7", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(rate(synapse_federation_soft_failed_events_total{instance=\"$instance\"}[$bucket_size]))", + "interval": "", + "legendFormat": "soft-failed events", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + 
"timeShift": null, + "title": "Soft-failed event rate", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "hertz", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": false + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "title": "Federation", @@ -4647,7 +4950,7 @@ "h": 7, "w": 12, "x": 0, - "y": 8 + "y": 33 }, "hiddenSeries": false, "id": 48, @@ -4749,7 +5052,7 @@ "h": 7, "w": 12, "x": 12, - "y": 8 + "y": 33 }, "hiddenSeries": false, "id": 104, @@ -4877,7 +5180,7 @@ "h": 7, "w": 12, "x": 0, - "y": 15 + "y": 40 }, "hiddenSeries": false, "id": 10, @@ -4981,7 +5284,7 @@ "h": 7, "w": 12, "x": 12, - "y": 15 + "y": 40 }, "hiddenSeries": false, "id": 11, @@ -5086,7 +5389,7 @@ "h": 7, "w": 12, "x": 0, - "y": 22 + "y": 47 }, "hiddenSeries": false, "id": 180, @@ -5168,6 +5471,126 @@ "align": false, "alignLevel": null } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "$datasource", + "fieldConfig": { + "defaults": { + "custom": {}, + "links": [] + }, + "overrides": [] + }, + "fill": 6, + "fillGradient": 0, + "gridPos": { + "h": 9, + "w": 12, + "x": 12, + "y": 47 + }, + "hiddenSeries": false, + "id": 200, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.3.7", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "histogram_quantile(0.99, sum(rate(synapse_storage_schedule_time_bucket{index=~\"$index\",instance=\"$instance\",job=\"$job\"}[$bucket_size])) by (le))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "99%", + "refId": "D" + }, + { + "expr": "histogram_quantile(0.9, sum(rate(synapse_storage_schedule_time_bucket{index=~\"$index\",instance=\"$instance\",job=\"$job\"}[$bucket_size])) by (le))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "90%", + "refId": "A" + }, + { + "expr": "histogram_quantile(0.75, sum(rate(synapse_storage_schedule_time_bucket{index=~\"$index\",instance=\"$instance\",job=\"$job\"}[$bucket_size])) by (le))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "75%", + "refId": "C" + }, + { + "expr": "histogram_quantile(0.5, sum(rate(synapse_storage_schedule_time_bucket{index=~\"$index\",instance=\"$instance\",job=\"$job\"}[$bucket_size])) by (le))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "50%", + "refId": "B" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Time waiting for DB connection quantiles", + "tooltip": { + "shared": true, + "sort": 2, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "decimals": null, + "format": "s", + "label": "", + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 
1, + "max": null, + "min": null, + "show": false + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "repeat": null, @@ -5916,7 +6339,7 @@ "h": 10, "w": 12, "x": 0, - "y": 84 + "y": 35 }, "hiddenSeries": false, "id": 1, @@ -6022,7 +6445,7 @@ "h": 10, "w": 12, "x": 12, - "y": 84 + "y": 35 }, "hiddenSeries": false, "id": 8, @@ -6126,7 +6549,7 @@ "h": 10, "w": 12, "x": 0, - "y": 94 + "y": 45 }, "hiddenSeries": false, "id": 38, @@ -6226,7 +6649,7 @@ "h": 10, "w": 12, "x": 12, - "y": 94 + "y": 45 }, "hiddenSeries": false, "id": 39, @@ -6258,8 +6681,9 @@ "steppedLine": false, "targets": [ { - "expr": "topk(10, rate(synapse_util_caches_cache:total{job=\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size]) - rate(synapse_util_caches_cache:hits{job=\"$job\",instance=\"$instance\"}[$bucket_size]))", + "expr": "topk(10, rate(synapse_util_caches_cache:total{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size]) - rate(synapse_util_caches_cache:hits{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size]))", "format": "time_series", + "interval": "", "intervalFactor": 2, "legendFormat": "{{name}} {{job}}-{{index}}", "refId": "A", @@ -6326,7 +6750,7 @@ "h": 9, "w": 12, "x": 0, - "y": 104 + "y": 55 }, "hiddenSeries": false, "id": 65, @@ -9051,7 +9475,7 @@ "h": 8, "w": 12, "x": 0, - "y": 119 + "y": 41 }, "hiddenSeries": false, "id": 156, @@ -9089,7 +9513,7 @@ "steppedLine": false, "targets": [ { - "expr": "synapse_admin_mau:current{instance=\"$instance\"}", + "expr": "synapse_admin_mau:current{instance=\"$instance\", job=~\"$job\"}", "format": "time_series", "interval": "", "intervalFactor": 1, @@ -9097,7 +9521,7 @@ "refId": "A" }, { - "expr": "synapse_admin_mau:max{instance=\"$instance\"}", + "expr": "synapse_admin_mau:max{instance=\"$instance\", job=~\"$job\"}", "format": "time_series", "interval": "", "intervalFactor": 1, @@ -9164,7 +9588,7 @@ "h": 8, "w": 12, "x": 12, - "y": 119 + "y": 41 }, "hiddenSeries": false, "id": 160, @@ -9484,7 +9908,7 @@ "h": 8, "w": 12, "x": 0, - "y": 73 + "y": 43 }, "hiddenSeries": false, "id": 168, @@ -9516,7 +9940,7 @@ { "expr": "rate(synapse_appservice_api_sent_events{instance=\"$instance\"}[$bucket_size])", "interval": "", - "legendFormat": "{{exported_service}}", + "legendFormat": "{{service}}", "refId": "A" } ], @@ -9579,7 +10003,7 @@ "h": 8, "w": 12, "x": 12, - "y": 73 + "y": 43 }, "hiddenSeries": false, "id": 171, @@ -9611,7 +10035,7 @@ { "expr": "rate(synapse_appservice_api_sent_transactions{instance=\"$instance\"}[$bucket_size])", "interval": "", - "legendFormat": "{{exported_service}}", + "legendFormat": "{{service}}", "refId": "A" } ], @@ -9959,7 +10383,6 @@ }, "yaxes": [ { - "$$hashKey": "object:165", "format": "hertz", "label": null, "logBase": 1, @@ -9968,7 +10391,6 @@ "show": true }, { - "$$hashKey": "object:166", "format": "short", "label": null, "logBase": 1, @@ -10071,7 +10493,6 @@ }, "yaxes": [ { - "$$hashKey": "object:390", "format": "hertz", "label": null, "logBase": 1, @@ -10080,7 +10501,6 @@ "show": true }, { - "$$hashKey": "object:391", "format": "short", "label": null, "logBase": 1, @@ -10169,7 +10589,6 @@ }, "yaxes": [ { - "$$hashKey": "object:390", "format": "hertz", "label": null, "logBase": 1, @@ -10178,7 +10597,6 @@ "show": true }, { - "$$hashKey": "object:391", "format": "short", "label": null, "logBase": 1, @@ -10470,5 +10888,5 @@ "timezone": "", "title": "Synapse", "uid": "000000012", - "version": 90 + "version": 99 } \ No newline at end of file diff --git a/debian/build_virtualenv 
b/debian/build_virtualenv index 68c86599536c..801ecb9086c6 100755 --- a/debian/build_virtualenv +++ b/debian/build_virtualenv @@ -100,3 +100,18 @@ esac # add a dependency on the right version of python to substvars. PYPKG=`basename $SNAKE` echo "synapse:pydepends=$PYPKG" >> debian/matrix-synapse-py3.substvars + + +# add a couple of triggers. This is needed so that dh-virtualenv can rebuild +# the venv when the system python changes (see +# https://dh-virtualenv.readthedocs.io/en/latest/tutorial.html#step-2-set-up-packaging-for-your-project) +# +# we do it here rather than the more conventional way of just adding it to +# debian/matrix-synapse-py3.triggers, because we need to add a trigger on the +# right version of python. +cat >>"debian/.debhelper/generated/matrix-synapse-py3/triggers" <<EOF +# triggers for dh-virtualenv +interest-noawait $SNAKE +interest dh-virtualenv-interpreter-update + +EOF diff --git a/debian/changelog b/debian/changelog --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,15 @@ +matrix-synapse-py3 (1.41.0) stable; urgency=medium + + * New synapse release 1.41.0. + + -- Synapse Packaging team <packages@matrix.org> Tue, 24 Aug 2021 15:31:45 +0100 + +matrix-synapse-py3 (1.41.0~rc1) stable; urgency=medium + + * New synapse release 1.41.0~rc1. + + -- Synapse Packaging team <packages@matrix.org> Wed, 18 Aug 2021 15:52:00 +0100 + matrix-synapse-py3 (1.40.0) stable; urgency=medium * New synapse release 1.40.0. @@ -20,6 +32,8 @@ matrix-synapse-py3 (1.40.0~rc1) stable; urgency=medium [ Richard van der Hoff ] * Drop backwards-compatibility code that was required to support Ubuntu Xenial. + * Update package triggers so that the virtualenv is correctly rebuilt + when the system python is rebuilt, on recent Python versions. [ Synapse Packaging team ] * New synapse release 1.40.0~rc1. diff --git a/debian/matrix-synapse-py3.triggers b/debian/matrix-synapse-py3.triggers deleted file mode 100644 index f8c1fdb021c9..000000000000 --- a/debian/matrix-synapse-py3.triggers +++ /dev/null @@ -1,9 +0,0 @@ -# Register interest in Python interpreter changes and -# don't make the Python package dependent on the virtualenv package -# processing (noawait) -interest-noawait /usr/bin/python3.5 -interest-noawait /usr/bin/python3.6 -interest-noawait /usr/bin/python3.7 - -# Also provide a symbolic trigger for all dh-virtualenv packages -interest dh-virtualenv-interpreter-update diff --git a/docker/conf/log.config b/docker/conf/log.config index a99462692628..7a216a36a046 100644 --- a/docker/conf/log.config +++ b/docker/conf/log.config @@ -18,18 +18,31 @@ handlers: backupCount: 6 # Does not include the current log file. encoding: utf8 - # Default to buffering writes to log file for efficiency. This means that - # there will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR - # logs will still be flushed immediately. + # Default to buffering writes to log file for efficiency. + # WARNING/ERROR logs will still be flushed immediately, but there will be a + # delay (of up to `period` seconds, or until the buffer is full with + # `capacity` messages) before INFO/DEBUG logs get written. buffer: - class: logging.handlers.MemoryHandler + class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler target: file - # The capacity is the number of log lines that are buffered before - # being written to disk. Increasing this will lead to better + + # The capacity is the maximum number of log lines that are buffered + # before being written to disk. Increasing this will lead to better # performance, at the expense of it taking longer for log lines to # be written to disk. + # This parameter is required. capacity: 10 - flushLevel: 30 # Flush for WARNING logs as well + + # Logs with a level at or above the flush level will cause the buffer to + # be flushed immediately.
+ # Default value: 40 (ERROR) + # Other values: 50 (CRITICAL), 30 (WARNING), 20 (INFO), 10 (DEBUG) + flushLevel: 30 # Flush immediately for WARNING logs and higher + + # The period of time, in seconds, between forced flushes. + # Messages will not be delayed for longer than this time. + # Default value: 5 seconds + period: 5 {% endif %} console: diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 10be12d63865..56e0141c2b3a 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -7,6 +7,7 @@ - [Installation](setup/installation.md) - [Using Postgres](postgres.md) - [Configuring a Reverse Proxy](reverse_proxy.md) + - [Configuring a Forward/Outbound Proxy](setup/forward_proxy.md) - [Configuring a Turn Server](turn-howto.md) - [Delegation](delegate.md) @@ -20,6 +21,7 @@ - [Homeserver Sample Config File](usage/configuration/homeserver_sample_config.md) - [Logging Sample Config File](usage/configuration/logging_sample_config.md) - [Structured Logging](structured_logging.md) + - [Templates](templates.md) - [User Authentication](usage/configuration/user_authentication/README.md) - [Single-Sign On]() - [OpenID Connect](openid.md) diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 61bed1e0d5d8..ea05bd6e4465 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -12,6 +12,7 @@ - [Delete local media](#delete-local-media) * [Delete a specific local media](#delete-a-specific-local-media) * [Delete local media by date or size](#delete-local-media-by-date-or-size) + * [Delete media uploaded by a user](#delete-media-uploaded-by-a-user) - [Purge Remote Media API](#purge-remote-media-api) # Querying media @@ -47,7 +48,8 @@ The API returns a JSON body like the following: ## List all media uploaded by a user Listing all media that has been uploaded by a local user can be achieved through -the use of the [List media of a user](user_admin_api.md#list-media-of-a-user) +the use of the +[List media uploaded by a user](user_admin_api.md#list-media-uploaded-by-a-user) Admin API. # Quarantine media @@ -281,6 +283,11 @@ The following fields are returned in the JSON response body: * `deleted_media`: an array of strings - List of deleted `media_id` * `total`: integer - Total number of deleted `media_id` +## Delete media uploaded by a user + +You can find details of how to delete multiple media uploaded by a user in +[User Admin API](user_admin_api.md#delete-media-uploaded-by-a-user). + # Purge Remote Media API The purge remote media API allows server admins to purge old cached remote media. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 160899754ede..6a9335d6ecfc 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -81,6 +81,16 @@ with a body of: "address": "<user_mail_2>" } ], + "external_ids": [ + { + "auth_provider": "<provider1>", + "external_id": "<user_id_provider_1>" + }, + { + "auth_provider": "<provider2>", + "external_id": "<user_id_provider_2>" + } + ], "avatar_url": "<avatar_url>", "admin": false, "deactivated": false @@ -90,26 +100,34 @@ with a body of: To use it, you will need to authenticate by providing an `access_token` for a server admin: [Admin API](../usage/administration/admin_api) +Returns HTTP status code: +- `201` - When a new user object was created. +- `200` - When a user was modified. + URL parameters: - `user_id`: fully-qualified user id: for example, `@user:server.com`. Body parameters: -- `password`, optional. If provided, the user's password is updated and all +- `password` - string, optional.
If provided, the user's password is updated and all devices are logged out. - -- `displayname`, optional, defaults to the value of `user_id`. - -- `threepids`, optional, allows setting the third-party IDs (email, msisdn) +- `displayname` - string, optional, defaults to the value of `user_id`. +- `threepids` - array, optional, allows setting the third-party IDs (email, msisdn) + - `medium` - string. Kind of third-party ID, either `email` or `msisdn`. + - `address` - string. Value of third-party ID. belonging to a user. - -- `avatar_url`, optional, must be a +- `external_ids` - array, optional. Allows setting the identifier of the external identity + provider for SSO (Single sign-on). Details in + [Sample Configuration File](../usage/configuration/homeserver_sample_config.html) + section `sso` and `oidc_providers`. + - `auth_provider` - string. ID of the external identity provider. Value of `idp_id` + in homeserver configuration. + - `external_id` - string, user ID in the external identity provider. +- `avatar_url` - string, optional, must be a [MXC URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris). - -- `admin`, optional, defaults to `false`. - -- `deactivated`, optional. If unspecified, deactivation state will be left +- `admin` - bool, optional, defaults to `false`. +- `deactivated` - bool, optional. If unspecified, deactivation state will be left unchanged on existing accounts and set to `false` for new accounts. A user cannot be erased by deactivating with this API. For details on deactivating users see [Deactivate Account](#deactivate-account). @@ -443,8 +461,9 @@ The following fields are returned in the JSON response body: - `joined_rooms` - An array of `room_id`. - `total` - Number of rooms. +## User media -## List media of a user +### List media uploaded by a user Gets a list of all local media that a specific `user_id` has created. By default, the response is ordered by descending creation date and ascending media ID. The newest media is on top. You can change the order with parameters @@ -543,7 +562,6 @@ The following fields are returned in the JSON response body: - `media` - An array of objects, each containing information about a media. Media objects contain the following fields: - - `created_ts` - integer - Timestamp when the content was uploaded in ms. - `last_access_ts` - integer - Timestamp when the content was last accessed in ms. - `media_id` - string - The id used to refer to the media. @@ -551,13 +569,58 @@ The following fields are returned in the JSON response body: - `media_type` - string - The MIME-type of the media. - `quarantined_by` - string - The user ID that initiated the quarantine request for this media. - - `safe_from_quarantine` - bool - Status if this media is safe from quarantining. - `upload_name` - string - The name the media was uploaded with. - - `next_token`: integer - Indication for pagination. See above. - `total` - integer - Total number of media. +### Delete media uploaded by a user + +This API deletes the *local* media from the disk of your own server +that a specific `user_id` has created. This includes any local thumbnails. + +This API will not affect media that has been uploaded to external +media repositories (e.g. https://github.com/turt2live/matrix-media-repo/). + +By default, the API deletes media ordered by descending creation date and ascending media ID. +The newest media is deleted first. You can change the order with parameters +`order_by` and `dir`.
If no `limit` is set the API deletes `100` files per request. + +The API is: + +``` +DELETE /_synapse/admin/v1/users/<user_id>/media +``` + +To use it, you will need to authenticate by providing an `access_token` for a +server admin: [Admin API](../usage/administration/admin_api) + +A response body like the following is returned: + +```json +{ + "deleted_media": [ + "abcdefghijklmnopqrstuvwx" + ], + "total": 1 +} +``` + +The following fields are returned in the JSON response body: + +* `deleted_media`: an array of strings - List of deleted `media_id` +* `total`: integer - Total number of deleted `media_id` + +**Note**: There is no `next_token`. This is not useful for deleting media, because +after deleting media the remaining media have a new order. + +**Parameters** + +This API has the same parameters as +[List media uploaded by a user](#list-media-uploaded-by-a-user). +The parameters let you, for example, limit the number of files deleted at once, or +delete the largest/smallest or newest/oldest files first. + ## Login as a user Get an access token that can be used to authenticate as that user. Useful for @@ -1013,3 +1076,22 @@ The following parameters should be set in the URL: - `user_id` - The fully qualified MXID: for example, `@user:server.com`. The user must be local. +### Check username availability + +Checks to see if a username is available, and valid, for the server. See [the client-server +API](https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available) +for more information. + +This endpoint will work even if registration is disabled on the server, unlike +`/_matrix/client/r0/register/available`. + +The API is: + +``` +GET /_synapse/admin/v1/username_available?username=$localpart +``` + +The request and response format is the same as the [/_matrix/client/r0/register/available](https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available) API. + +To use it, you will need to authenticate by providing an `access_token` for a +server admin: [Admin API](../usage/administration/admin_api) diff --git a/docs/manhole.md b/docs/manhole.md index 37d1d7823c00..db92df88dcc9 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -67,7 +67,7 @@ This gives a Python REPL in which `hs` gives access to the `synapse.server.HomeServer` object - which in turn gives access to many other parts of the process. -Note that any call which returns a coroutine will need to be wrapped in `ensureDeferred`. +Note that, prior to Synapse 1.41, any call which returns a coroutine will need to be wrapped in `ensureDeferred`. As a simple example, retrieving an event from the database: diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index 76bb45aff2e1..5f8d20129e1a 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -33,6 +33,19 @@ Let's assume that we expect clients to connect to our server at `https://example.com:8448`. The following sections detail the configuration of the reverse proxy and the homeserver. + +## Homeserver Configuration + +The HTTP configuration will need to be updated for Synapse to correctly record +client IP addresses and generate redirect URLs while behind a reverse proxy. + +In `homeserver.yaml` set `x_forwarded: true` in the port 8008 section and +consider setting `bind_addresses: ['127.0.0.1']` so that the server only +listens to traffic on localhost. (Do not change `bind_addresses` to `127.0.0.1` +when using a containerized Synapse, as that will prevent it from responding +to proxied traffic.)
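For illustration, a minimal sketch of the listener block that paragraph describes, following the structure of the sample `homeserver.yaml` (the `resources` list is an assumption; adjust it to your deployment):

```yaml
listeners:
  - port: 8008
    tls: false
    type: http
    # Trust the X-Forwarded-For header added by the reverse proxy, so that
    # client IP addresses are recorded correctly.
    x_forwarded: true
    # Only listen on localhost, since all traffic arrives via the proxy.
    # Omit this when running Synapse in a container (see the note above).
    bind_addresses: ['127.0.0.1']
    resources:
      - names: [client, federation]
        compress: false
```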
+ + ## Reverse-proxy configuration examples **NOTE**: You only need one of these. @@ -239,16 +252,6 @@ relay "matrix_federation" { } ``` -## Homeserver Configuration - -You will also want to set `bind_addresses: ['127.0.0.1']` and -`x_forwarded: true` for port 8008 in `homeserver.yaml` to ensure that -client IP addresses are recorded correctly. - -Having done so, you can then use `https://matrix.example.com` (instead -of `https://matrix.example.com:8448`) as the "Custom server" when -connecting to Synapse from a client. - ## Health check endpoint diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1a217f35dba9..3ec76d5abf21 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -210,6 +210,8 @@ presence: # # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. # +# Note: The value is ignored when an HTTP proxy is in use +# #ip_range_blacklist: # - '127.0.0.0/8' # - '10.0.0.0/8' @@ -563,6 +565,19 @@ retention: # #next_link_domain_whitelist: ["matrix.org"] +# Templates to use when generating email or HTML page contents. +# +templates: + # Directory in which Synapse will try to find template files to use to generate + # email or HTML page contents. + # If not set, or a file is not found within the template directory, a default + # template from within the Synapse package will be used. + # + # See https://matrix-org.github.io/synapse/latest/templates.html for more + # information about using custom templates. + # + #custom_template_directory: /path/to/custom/templates/ + ## TLS ## @@ -711,6 +726,15 @@ caches: # #expiry_time: 30m + # Controls how long the results of a /sync request are cached for after + # a successful response is returned. A higher duration can help clients with + # intermittent connections, at the cost of higher memory usage. + # + # By default, this is zero, which means that sync responses are not cached + # at all. + # + #sync_response_cache_duration: 2m + ## Database ## @@ -963,6 +987,8 @@ media_store_path: "DATADIR/media_store" # This must be specified if url_preview_enabled is set. It is recommended that # you uncomment the following list as a starting point. # +# Note: The value is ignored when an HTTP proxy is in use +# #url_preview_ip_range_blacklist: # - '127.0.0.0/8' # - '10.0.0.0/8' @@ -1882,6 +1908,9 @@ cas_config: # Additional settings to use with single-sign on systems such as OpenID Connect, # SAML2 and CAS. # +# Server admins can configure custom templates for pages related to SSO. See +# https://matrix-org.github.io/synapse/latest/templates.html for more information. +# sso: # A list of client URLs which are whitelisted so that the user does not # have to confirm giving access to their account to the URL. Any client @@ -1914,169 +1943,6 @@ sso: # #update_profile_information: true - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. - # - # Synapse will look for the following templates in this directory: - # - # * HTML page to prompt the user to choose an Identity Provider during - # login: 'sso_login_idp_picker.html'. - # - # This is only used if multiple SSO Identity Providers are configured. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL that the user will be redirected to after - # login. - # - # * server_name: the homeserver's name. 
- # - # * providers: a list of available Identity Providers. Each element is - # an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # The rendered HTML page should contain a form which submits its results - # back as a GET request, with the following query parameters: - # - # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed - # to the template) - # - # * idp: the 'idp_id' of the chosen IDP. - # - # * HTML page to prompt new users to enter a userid and confirm other - # details: 'sso_auth_account_details.html'. This is only shown if the - # SSO implementation (with any user_mapping_provider) does not return - # a localpart. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. - # - # * idp: details of the SSO Identity Provider that the user logged in - # with: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * user_attributes: an object containing details about the user that - # we received from the IdP. May have the following attributes: - # - # * display_name: the user's display_name - # * emails: a list of email addresses - # - # The template should render a form which submits the following fields: - # - # * username: the localpart of the user's chosen user id - # - # * HTML page allowing the user to consent to the server's terms and - # conditions. This is only shown for new users, and only if - # `user_consent.require_at_registration` is set. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. - # - # * user_id: the user's matrix proposed ID. - # - # * user_profile.display_name: the user's proposed display name, if any. - # - # * consent_version: the version of the terms that the user will be - # shown - # - # * terms_url: a link to the page showing the terms. - # - # The template should render a form which submits the following fields: - # - # * accepted_version: the version of the terms accepted by the user - # (ie, 'consent_version' from the input variables). - # - # * HTML page for a confirmation step before redirecting back to the client - # with the login token: 'sso_redirect_confirm.html'. - # - # When rendering, this template is given the following variables: - # - # * redirect_url: the URL the user is about to be redirected to. - # - # * display_url: the same as `redirect_url`, but with the query - # parameters stripped. The intention is to have a - # human-readable URL to show to users, not to use it as - # the final address to redirect to. - # - # * server_name: the homeserver's name. - # - # * new_user: a boolean indicating whether this is the user's first time - # logging in. - # - # * user_id: the user's matrix ID. - # - # * user_profile.avatar_url: an MXC URI for the user's avatar, if any. - # None if the user has not set an avatar. - # - # * user_profile.display_name: the user's display name. None if the user - # has not set a display name. 
- # - # * HTML page which notifies the user that they are authenticating to confirm - # an operation on their account during the user interactive authentication - # process: 'sso_auth_confirm.html'. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL the user is about to be redirected to. - # - # * description: the operation which the user is being asked to confirm - # - # * idp: details of the Identity Provider that we will use to confirm - # the user's identity: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * HTML page shown after a successful user interactive authentication session: - # 'sso_auth_success.html'. - # - # Note that this page must include the JavaScript which notifies of a successful authentication - # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). - # - # This template has no additional variables. - # - # * HTML page shown after a user-interactive authentication session which - # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. - # - # When rendering, this template is given the following variables: - # * server_name: the homeserver's name. - # * user_id_to_verify: the MXID of the user that we are trying to - # validate. - # - # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) - # attempts to login: 'sso_account_deactivated.html'. - # - # This template has no additional variables. - # - # * HTML page to display to users if something goes wrong during the - # OpenID Connect authentication process: 'sso_error.html'. - # - # When rendering, this template is given two variables: - # * error: the technical name of the error - # * error_description: a human-readable message for the error - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" - # JSON web token integration. The following settings can be used to make # Synapse JSON web tokens for authentication, instead of its internal @@ -2207,6 +2073,9 @@ ui_auth: # Configuration for sending emails from Synapse. # +# Server admins can configure custom templates for email content. See +# https://matrix-org.github.io/synapse/latest/templates.html for more information. +# email: # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'. # @@ -2229,6 +2098,14 @@ email: # #require_transport_security: true + # Uncomment the following to disable TLS for SMTP. + # + # By default, if the server supports TLS, it will be used, and the server + # must present a certificate that is valid for 'smtp_host'. If this option + # is set to false, TLS will not be used. + # + #enable_tls: false + # notif_from defines the "From" address to use when sending emails. # It must be set if email sending is enabled. # @@ -2275,49 +2152,6 @@ email: # #invite_client_location: https://app.element.io - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. 
- # - # Synapse will look for the following templates in this directory: - # - # * The contents of email notifications of missed events: 'notif_mail.html' and - # 'notif_mail.txt'. - # - # * The contents of account expiry notice emails: 'notice_expiry.html' and - # 'notice_expiry.txt'. - # - # * The contents of password reset emails sent by the homeserver: - # 'password_reset.html' and 'password_reset.txt' - # - # * An HTML page that a user will see when they follow the link in the password - # reset email. The user will be asked to confirm the action before their - # password is reset: 'password_reset_confirmation.html' - # - # * HTML pages for success and failure that a user will see when they confirm - # the password reset flow using the page above: 'password_reset_success.html' - # and 'password_reset_failure.html' - # - # * The contents of address verification emails sent during registration: - # 'registration.html' and 'registration.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent during registration: - # 'registration_success.html' and 'registration_failure.html' - # - # * The contents of address verification emails sent when an address is added - # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent when an address is added - # to a Matrix account: 'add_threepid_success.html' and - # 'add_threepid_failure.html' - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" - # Subjects to use when sending emails from Synapse. # # The placeholder '%(app)s' will be replaced with the value of the 'app_name' diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 669e60008113..2485ad25edfc 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -24,18 +24,31 @@ handlers: backupCount: 3 # Does not include the current log file. encoding: utf8 - # Default to buffering writes to log file for efficiency. This means that - # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR - # logs will still be flushed immediately. + # Default to buffering writes to log file for efficiency. + # WARNING/ERROR logs will still be flushed immediately, but there will be a + # delay (of up to `period` seconds, or until the buffer is full with + # `capacity` messages) before INFO/DEBUG logs get written. buffer: - class: logging.handlers.MemoryHandler + class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler target: file - # The capacity is the number of log lines that are buffered before - # being written to disk. Increasing this will lead to better + + # The capacity is the maximum number of log lines that are buffered + # before being written to disk. Increasing this will lead to better # performance, at the expensive of it taking longer for log lines to # be written to disk. + # This parameter is required. capacity: 10 - flushLevel: 30 # Flush for WARNING logs as well + + # Logs with a level at or above the flush level will cause the buffer to + # be flushed immediately. + # Default value: 40 (ERROR) + # Other values: 50 (CRITICAL), 30 (WARNING), 20 (INFO), 10 (DEBUG) + flushLevel: 30 # Flush immediately for WARNING logs and higher + + # The period of time, in seconds, between forced flushes. 
+  # Messages will not be delayed for longer than this time.
+  # Default value: 5 seconds
+  period: 5

  # A handler that writes logs to stderr. Unused by default, but can be used
  # instead of "buffer" and "file" in the logger handlers.
diff --git a/docs/setup/forward_proxy.md b/docs/setup/forward_proxy.md
new file mode 100644
index 000000000000..494c14893b26
--- /dev/null
+++ b/docs/setup/forward_proxy.md
@@ -0,0 +1,74 @@
+# Using a forward proxy with Synapse
+
+You can use Synapse with a forward or outbound proxy. An example of when
+this is necessary is in corporate environments behind a DMZ (demilitarized zone).
+Synapse supports routing outbound HTTP(S) requests via a proxy. Only HTTP(S)
+proxies are supported, not SOCKS proxies or anything else.
+
+## Configure
+
+The `http_proxy`, `https_proxy`, and `no_proxy` environment variables are used to
+specify proxy settings. These environment variables are case-insensitive.
+- `http_proxy`: Proxy server to use for HTTP requests.
+- `https_proxy`: Proxy server to use for HTTPS requests.
+- `no_proxy`: Comma-separated list of hosts, IP addresses, or IP ranges in CIDR
+  format which should not use the proxy. Synapse will connect directly to these hosts.
+
+The `http_proxy` and `https_proxy` environment variables have the form:
+`[scheme://][<username>:<password>@]<host>[:<port>]`
+- Supported schemes are `http://` and `https://`. The default scheme is `http://`
+  for compatibility reasons; it is recommended to set a scheme. If the scheme is set
+  to `https://`, the connection between Synapse and the proxy uses TLS.
+
+  **NOTE**: Synapse validates the certificates. If the certificate is not
+  valid, then the connection is dropped.
+- The default port, if not given, is `1080`.
+- Username and password are optional and will be used to authenticate against
+  the proxy.
+
+**Examples**
+- HTTP_PROXY=http://USERNAME:PASSWORD@10.0.1.1:8080/
+- HTTPS_PROXY=http://USERNAME:PASSWORD@proxy.example.com:8080/
+- NO_PROXY=master.hostname.example.com,10.1.0.0/16,172.30.0.0/16
+
+**NOTE**:
+Synapse does not apply the IP blacklist to connections through the proxy (since
+the DNS resolution is done by the proxy). It is expected that the proxy or firewall
+will apply blacklisting of IP addresses.
+
+## Connection types
+
+The proxy will be **used** for:
+
+- push
+- url previews
+- phone-home stats
+- recaptcha validation
+- CAS auth validation
+- OpenID Connect
+- Outbound federation
+- Federation (checking public key revocation)
+- Fetching public keys of other servers
+- Downloading remote media
+
+It will **not be used** for:
+
+- Application Services
+- Identity servers
+- In worker configurations
+  - connections between workers
+  - connections from workers to Redis
+
+## Troubleshooting
+
+If a proxy server is used with TLS (HTTPS) and no connections are established,
+it is most likely due to the proxy's certificates. To test this, certificate
+validation in Synapse can be deactivated.
+
+**NOTE**: This has an impact on security and is for testing purposes only!
+
+To deactivate the certificate validation, the following setting must be made in
+[homeserver.yaml](../usage/configuration/homeserver_sample_config.md).
+
+```yaml
+use_insecure_ssl_client_just_for_testing_do_not_use: true
+```
diff --git a/docs/templates.md b/docs/templates.md
new file mode 100644
index 000000000000..a240f58b54fd
--- /dev/null
+++ b/docs/templates.md
@@ -0,0 +1,239 @@
+# Templates
+
+Synapse uses parametrised templates to generate the content of emails it sends and
+webpages it shows to users.
+
+By default, Synapse will use the templates listed [here](https://github.com/matrix-org/synapse/tree/master/synapse/res/templates).
+Server admins can configure an additional directory for Synapse to look for templates
+in, allowing them to specify custom templates:
+
+```yaml
+templates:
+  custom_template_directory: /path/to/custom/templates/
+```
+
+If this setting is not set, or the files named below are not found within the directory,
+default templates from within the Synapse package will be used.
+
+Templates that are given variables when being rendered are rendered using [Jinja 2](https://jinja.palletsprojects.com/en/2.11.x/).
+Templates rendered by Jinja 2 can also access two functions on top of the functions
+already available as part of Jinja 2:
+
+```python
+format_ts(value: int, format: str) -> str
+```
+
+Formats a timestamp (given in milliseconds) according to the given `strftime`-style
+format string.
+
+Example: `reason.last_sent_ts|format_ts("%c")`
+
+```python
+mxc_to_http(value: str, width: int, height: int, resize_method: str = "crop") -> str
+```
+
+Turns a `mxc://` URL for media content into an HTTP(S) one using the homeserver's
+`public_baseurl` configuration setting as the URL's base.
+
+Example: `message.sender_avatar_url|mxc_to_http(32,32)`
+
+
+## Email templates
+
+Below are the templates Synapse will look for when generating the content of an email:
+
+* `notif_mail.html` and `notif_mail.txt`: The contents of email notifications of missed
+  events.
+  When rendering, this template is given the following variables:
+  * `user_display_name`: the display name for the user receiving the notification
+  * `unsubscribe_link`: the link users can click to unsubscribe from email notifications
+  * `summary_text`: a summary of the notification(s). The text used can be customised
+    by configuring the various settings in the `email.subjects` section of the
+    configuration file.
+  * `rooms`: a list of rooms containing events to include in the email. Each element is
+    an object with the following attributes:
+    * `title`: a human-readable name for the room
+    * `hash`: a hash of the ID of the room
+    * `invite`: a boolean, which is `True` if the room is an invite the user hasn't
+      accepted yet, `False` otherwise
+    * `notifs`: a list of events, or an empty list if `invite` is `True`. Each element
+      is an object with the following attributes:
+      * `link`: a `matrix.to` link to the event
+      * `ts`: the time in milliseconds at which the event was received
+      * `messages`: a list of messages containing one message before the event, the
+        message in the event, and one message after the event. Each element is an
+        object with the following attributes:
+        * `event_type`: the type of the event
+        * `is_historical`: a boolean, which is `False` if the message is the one
+          that triggered the notification, `True` otherwise
+        * `id`: the ID of the event
+        * `ts`: the time in milliseconds at which the event was sent
+        * `sender_name`: the display name for the event's sender
+        * `sender_avatar_url`: the avatar URL (as a `mxc://` URL) for the event's
+          sender
+        * `sender_hash`: a hash of the user ID of the sender
+    * `link`: a `matrix.to` link to the room
+  * `reason`: information on the event that triggered the email to be sent. It's an
+    object with the following attributes:
+    * `room_id`: the ID of the room the event was sent in
+    * `room_name`: a human-readable name for the room the event was sent in
+    * `now`: the current time in milliseconds
+    * `received_at`: the time in milliseconds at which the event was received
+    * `delay_before_mail_ms`: the amount of time in milliseconds Synapse always waits
+      before ever emailing about a notification (to give the user a chance to respond
+      to other push notifications, or to notice the event in an open client)
+    * `last_sent_ts`: the time in milliseconds at which a notification was last sent
+      for an event in this room
+    * `throttle_ms`: the minimum amount of time in milliseconds allowed between two
+      notifications for this room
+* `password_reset.html` and `password_reset.txt`: The contents of password reset emails
+  sent by the homeserver.
+  When rendering, these templates are given a `link` variable which contains the link the
+  user must click in order to reset their password.
+* `registration.html` and `registration.txt`: The contents of address verification emails
+  sent during registration.
+  When rendering, these templates are given a `link` variable which contains the link the
+  user must click in order to validate their email address.
+* `add_threepid.html` and `add_threepid.txt`: The contents of address verification emails
+  sent when an address is added to a Matrix account.
+  When rendering, these templates are given a `link` variable which contains the link the
+  user must click in order to validate their email address.
+
+
+## HTML page templates for registration and password reset
+
+Below are the templates Synapse will look for when generating pages related to
+registration and password reset:
+
+* `password_reset_confirmation.html`: An HTML page that a user will see when they follow
+  the link in the password reset email. The user will be asked to confirm the action
+  before their password is reset.
+  When rendering, this template is given the following variables:
+  * `sid`: the session ID for the password reset
+  * `token`: the token for the password reset
+  * `client_secret`: the client secret for the password reset
+* `password_reset_success.html` and `password_reset_failure.html`: HTML pages for success
+  and failure that a user will see when they confirm the password reset flow using the
+  page above.
+  When rendering, `password_reset_success.html` is given no variables, and
+  `password_reset_failure.html` is given a `failure_reason`, which contains the reason
+  for the password reset failure.
+* `registration_success.html` and `registration_failure.html`: HTML pages for success and
+  failure that a user will see when they follow the link in an address verification email
+  sent during registration.
+  When rendering, `registration_success.html` is given no variables, and
+  `registration_failure.html` is given a `failure_reason`, which contains the reason
+  for the registration failure.
+* `add_threepid_success.html` and `add_threepid_failure.html`: HTML pages for success and
+  failure that a user will see when they follow the link in an address verification email
+  sent when an address is added to a Matrix account.
+  When rendering, `add_threepid_success.html` is given no variables, and
+  `add_threepid_failure.html` is given a `failure_reason`, which contains the reason
+  the address could not be validated.
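+
+Because these are ordinary Jinja 2 templates, a custom template can be smoke-tested
+locally before deploying it. A minimal sketch follows; the template directory and the
+`link` value are placeholders, not part of Synapse:
+
+```python
+from jinja2 import Environment, FileSystemLoader, select_autoescape
+
+# Hypothetical custom template directory -- adjust to your deployment.
+env = Environment(
+    loader=FileSystemLoader(["/path/to/custom/templates"]),
+    autoescape=select_autoescape(),
+)
+
+# password_reset.html is rendered with a single `link` variable.
+template = env.get_template("password_reset.html")
+print(template.render(link="https://example.com/reset?token=placeholder"))
+```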
+
+
+## HTML page templates for Single Sign-On (SSO)
+
+Below are the templates Synapse will look for when generating pages related to SSO:
+
+* `sso_login_idp_picker.html`: HTML page to prompt the user to choose an
+  Identity Provider during login.
+  This is only used if multiple SSO Identity Providers are configured.
+  When rendering, this template is given the following variables:
+  * `redirect_url`: the URL that the user will be redirected to after
+    login.
+  * `server_name`: the homeserver's name.
+  * `providers`: a list of available Identity Providers. Each element is
+    an object with the following attributes:
+    * `idp_id`: unique identifier for the IdP
+    * `idp_name`: user-facing name for the IdP
+    * `idp_icon`: if specified in the IdP config, an MXC URI for an icon
+      for the IdP
+    * `idp_brand`: if specified in the IdP config, a textual identifier
+      for the brand of the IdP
+  The rendered HTML page should contain a form which submits its results
+  back as a GET request, with the following query parameters:
+  * `redirectUrl`: the client redirect URI (ie, the `redirect_url` passed
+    to the template)
+  * `idp`: the `idp_id` of the chosen IdP.
+* `sso_auth_account_details.html`: HTML page to prompt new users to enter a
+  user ID and confirm other details. This is only shown if the
+  SSO implementation (with any `user_mapping_provider`) does not return
+  a localpart.
+  When rendering, this template is given the following variables:
+  * `server_name`: the homeserver's name.
+  * `idp`: details of the SSO Identity Provider that the user logged in
+    with: an object with the following attributes:
+    * `idp_id`: unique identifier for the IdP
+    * `idp_name`: user-facing name for the IdP
+    * `idp_icon`: if specified in the IdP config, an MXC URI for an icon
+      for the IdP
+    * `idp_brand`: if specified in the IdP config, a textual identifier
+      for the brand of the IdP
+  * `user_attributes`: an object containing details about the user that
+    we received from the IdP. May have the following attributes:
+    * `display_name`: the user's display name
+    * `emails`: a list of email addresses
+  The template should render a form which submits the following fields:
+  * `username`: the localpart of the user's chosen user id
+* `sso_new_user_consent.html`: HTML page allowing the user to consent to the
+  server's terms and conditions. This is only shown for new users, and only if
+  `user_consent.require_at_registration` is set.
+  When rendering, this template is given the following variables:
+  * `server_name`: the homeserver's name.
+  * `user_id`: the user's matrix proposed ID.
+  * `user_profile.display_name`: the user's proposed display name, if any.
+  * `consent_version`: the version of the terms that the user will be
+    shown
+  * `terms_url`: a link to the page showing the terms.
+  The template should render a form which submits the following fields:
+  * `accepted_version`: the version of the terms accepted by the user
+    (ie, `consent_version` from the input variables).
+* `sso_redirect_confirm.html`: HTML page for a confirmation step before redirecting back
+  to the client with the login token.
+  When rendering, this template is given the following variables:
+  * `redirect_url`: the URL the user is about to be redirected to.
+  * `display_url`: the same as `redirect_url`, but with the query
+    parameters stripped. The intention is to have a
+    human-readable URL to show to users, not to use it as
+    the final address to redirect to.
+  * `server_name`: the homeserver's name.
+  * `new_user`: a boolean indicating whether this is the user's first time
+    logging in.
+  * `user_id`: the user's matrix ID.
+  * `user_profile.avatar_url`: an MXC URI for the user's avatar, if any.
+    `None` if the user has not set an avatar.
+  * `user_profile.display_name`: the user's display name. `None` if the user
+    has not set a display name.
+* `sso_auth_confirm.html`: HTML page which notifies the user that they are authenticating
+  to confirm an operation on their account during the user interactive authentication
+  process.
+  When rendering, this template is given the following variables:
+  * `redirect_url`: the URL the user is about to be redirected to.
+  * `description`: the operation which the user is being asked to confirm
+  * `idp`: details of the Identity Provider that we will use to confirm
+    the user's identity: an object with the following attributes:
+    * `idp_id`: unique identifier for the IdP
+    * `idp_name`: user-facing name for the IdP
+    * `idp_icon`: if specified in the IdP config, an MXC URI for an icon
+      for the IdP
+    * `idp_brand`: if specified in the IdP config, a textual identifier
+      for the brand of the IdP
+* `sso_auth_success.html`: HTML page shown after a successful user interactive
+  authentication session.
+  Note that this page must include the JavaScript which notifies of a successful
+  authentication (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
+  This template has no additional variables.
+* `sso_auth_bad_user.html`: HTML page shown after a user-interactive authentication
+  session which does not map correctly onto the expected user.
+  When rendering, this template is given the following variables:
+  * `server_name`: the homeserver's name.
+  * `user_id_to_verify`: the MXID of the user that we are trying to
+    validate.
+* `sso_account_deactivated.html`: HTML page shown during single sign-on if a deactivated
+  user (according to Synapse's database) attempts to log in.
+  This template has no additional variables.
+* `sso_error.html`: HTML page to display to users if something goes wrong during the
+  OpenID Connect authentication process.
+  When rendering, this template is given two variables:
+  * `error`: the technical name of the error
+  * `error_description`: a human-readable message for the error
diff --git a/docs/upgrade.md b/docs/upgrade.md
index ce9167e6de13..e5d386b02f7b 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -86,6 +86,50 @@ process, for example:
```

+# Upgrading to v1.41.0
+
+## Add support for routing outbound HTTP requests via a proxy for federation
+
+Since Synapse 1.6.0 (2019-11-26) you can set a proxy for outbound HTTP requests via
+the http_proxy/https_proxy environment variables. This proxy was used for:
+- push
+- url previews
+- phone-home stats
+- recaptcha validation
+- CAS auth validation
+- OpenID Connect
+- Federation (checking public key revocation)
+
+In this version we have added proxy support for outbound requests for:
+- Outbound federation
+- Downloading remote media
+- Fetching public keys of other servers
+
+These requests use the same proxy configuration. If you have a proxy configured, we
+recommend verifying it; it may be necessary to adjust the `no_proxy`
+environment variable.
+
+See [using a forward proxy with Synapse documentation](setup/forward_proxy.md) for
+details.
+
+## Deprecation of `template_dir`
+
+The `template_dir` settings in the `sso`, `account_validity` and `email` sections of the
+configuration file are now deprecated.
+Server admins should use the new `templates.custom_template_directory` setting in the
+configuration file, and use a single custom template directory for all of the
+aforementioned features. Template file names remain unchanged. See
+[the related documentation](https://matrix-org.github.io/synapse/latest/templates.html)
+for more information and examples.
+
+We plan to remove support for these settings in October 2021.
+
+## `/_synapse/admin/v1/users/{userId}/media` must be handled by media workers
+
+The [media repository worker documentation](https://matrix-org.github.io/synapse/latest/workers.html#synapseappmedia_repository)
+has been updated to reflect that calls to `/_synapse/admin/v1/users/{userId}/media`
+must now be handled by media repository workers. This is due to the new `DELETE` method
+of this endpoint modifying the media store.
+
# Upgrading to v1.39.0

## Deprecation of the current third-party rules module interface
diff --git a/docs/workers.md b/docs/workers.md
index d8672324c301..2e63f0345288 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -214,6 +214,7 @@ expressions:
    ^/_matrix/federation/v1/send/

    # Client API requests
+    ^/_matrix/client/(api/v1|r0|unstable)/createRoom$
    ^/_matrix/client/(api/v1|r0|unstable)/publicRooms$
    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/joined_members$
    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/context/.*$
@@ -425,10 +426,12 @@ Handles the media repository. It can handle all endpoints starting with:
    ^/_synapse/admin/v1/user/.*/media.*$
    ^/_synapse/admin/v1/media/.*$
    ^/_synapse/admin/v1/quarantine_media/.*$
+    ^/_synapse/admin/v1/users/.*/media$

You should also set `enable_media_repo: False` in the shared configuration
file to stop the main synapse running background jobs related to managing the
-media repository.
+media repository. Note that doing so will prevent the main process from being
+able to handle the above endpoints.

In the `media_repository` worker configuration file, configure the http
listener to expose the `media` resource.
For example: diff --git a/mypy.ini b/mypy.ini index 8717ae738e31..e1b9405daa85 100644 --- a/mypy.ini +++ b/mypy.ini @@ -86,6 +86,7 @@ files = tests/test_event_auth.py, tests/test_utils, tests/handlers/test_password_providers.py, + tests/handlers/test_room_summary.py, tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_itertools.py, diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages index 0ed1c679fd94..e9f89e38efaa 100755 --- a/scripts-dev/build_debian_packages +++ b/scripts-dev/build_debian_packages @@ -20,12 +20,12 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional, Sequence DISTS = ( - "debian:buster", + "debian:buster", # oldstable: EOL 2022-08 "debian:bullseye", + "debian:bookworm", "debian:sid", "ubuntu:bionic", # 18.04 LTS (our EOL forced by Py36 on 2021-12-23) "ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14) - "ubuntu:groovy", # 20.10 (EOL 2021-07-07) "ubuntu:hirsute", # 21.04 (EOL 2022-01-05) ) diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment index af6d32e3321c..393a548d5897 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment @@ -11,7 +11,7 @@ set -e git remote set-branches --add origin develop git fetch -q origin develop -pr="$BUILDKITE_PULL_REQUEST" +pr="$PULL_REQUEST_NUMBER" # if there are changes in the debian directory, check that the debian changelog # has been updated diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index cba015d942ca..5d0ef8dd3a73 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -65,4 +65,4 @@ if [[ -n "$1" ]]; then fi # Run the tests! -go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... +go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403,msc2716 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... 
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 869eb2372d51..809eff166ab2 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -94,7 +94,7 @@ else "scripts-dev/build_debian_packages" "scripts-dev/sign_json" "scripts-dev/update_database" - "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite" + "contrib" "synctl" "setup.py" "synmark" "stubs" ".ci" ) fi fi diff --git a/synapse/__init__.py b/synapse/__init__.py index 919293cd80c5..ef3770262e8f 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ except ImportError: pass -__version__ = "1.40.0" +__version__ = "1.41.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index f32a40ba4ae6..8abcdfd4fd9a 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -76,6 +76,8 @@ class RoomVersion: # MSC2716: Adds m.room.power_levels -> content.historical field to control # whether "insertion", "chunk", "marker" events can be sent msc2716_historical = attr.ib(type=bool) + # MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events + msc2716_redactions = attr.ib(type=bool) class RoomVersions: @@ -92,6 +94,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V2 = RoomVersion( "2", @@ -106,6 +109,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V3 = RoomVersion( "3", @@ -120,6 +124,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V4 = RoomVersion( "4", @@ -134,6 +139,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V5 = RoomVersion( "5", @@ -148,6 +154,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V6 = RoomVersion( "6", @@ -162,6 +169,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) MSC2176 = RoomVersion( "org.matrix.msc2176", @@ -176,6 +184,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V7 = RoomVersion( "7", @@ -190,6 +199,22 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=True, msc2716_historical=False, + msc2716_redactions=False, + ) + V8 = RoomVersion( + "8", + RoomDisposition.STABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + msc2176_redaction_rules=False, + msc3083_join_rules=True, + msc2403_knocking=True, + msc2716_historical=False, + msc2716_redactions=False, ) MSC2716 = RoomVersion( "org.matrix.msc2716", @@ -204,10 +229,11 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=True, msc2716_historical=True, + msc2716_redactions=False, ) - V8 = RoomVersion( - "8", - RoomDisposition.STABLE, + MSC2716v2 = RoomVersion( + "org.matrix.msc2716v2", + RoomDisposition.UNSTABLE, EventFormatVersions.V3, StateResolutionVersions.V2, enforce_key_validity=True, @@ -215,9 +241,10 @@ class RoomVersions: strict_canonicaljson=True, limit_notifications_power_levels=True, 
msc2176_redaction_rules=False, - msc3083_join_rules=True, + msc3083_join_rules=False, msc2403_knocking=True, - msc2716_historical=False, + msc2716_historical=True, + msc2716_redactions=True, ) @@ -266,7 +293,7 @@ class RoomVersionCapability: ), RoomVersionCapability( "restricted", - None, + RoomVersions.V8, lambda room_version: room_version.msc3083_join_rules, ), ) diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 3234d9ebba07..7396db93c62f 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -38,7 +38,6 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore from synapse.server import HomeServer from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -58,7 +57,6 @@ class AdminCmdSlavedStore( SlavedPushRuleStore, SlavedEventStore, SlavedClientIpStore, - RoomStore, BaseSlavedStore, ): pass diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 3b7131af8fa2..845e6a822046 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -64,42 +64,41 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client.v1 import events, login, presence, room -from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet -from synapse.rest.client.v1.profile import ( - ProfileAvatarURLRestServlet, - ProfileDisplaynameRestServlet, - ProfileRestServlet, -) -from synapse.rest.client.v1.push_rule import PushRuleRestServlet -from synapse.rest.client.v1.voip import VoipRestServlet -from synapse.rest.client.v2_alpha import ( +from synapse.rest.client import ( account_data, + events, groups, + login, + presence, read_marker, receipts, + room, room_keys, sync, tags, user_directory, ) -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.rest.client.v2_alpha.account import ThreepidRestServlet -from synapse.rest.client.v2_alpha.account_data import ( - AccountDataServlet, - RoomAccountDataServlet, -) -from synapse.rest.client.v2_alpha.devices import DevicesRestServlet -from synapse.rest.client.v2_alpha.keys import ( +from synapse.rest.client._base import client_patterns +from synapse.rest.client.account import ThreepidRestServlet +from synapse.rest.client.account_data import AccountDataServlet, RoomAccountDataServlet +from synapse.rest.client.devices import DevicesRestServlet +from synapse.rest.client.initial_sync import InitialSyncRestServlet +from synapse.rest.client.keys import ( KeyChangesServlet, KeyQueryServlet, OneTimeKeyServlet, ) -from synapse.rest.client.v2_alpha.register import RegisterRestServlet -from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet +from synapse.rest.client.profile import ( + ProfileAvatarURLRestServlet, + ProfileDisplaynameRestServlet, + ProfileRestServlet, +) +from synapse.rest.client.push_rule import PushRuleRestServlet +from synapse.rest.client.register import RegisterRestServlet +from synapse.rest.client.sendtodevice import 
SendToDeviceRestServlet
from synapse.rest.client.versions import VersionsRestServlet
+from synapse.rest.client.voip import VoipRestServlet
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.synapse.client import build_synapse_client_resource_tree
@@ -114,6 +113,7 @@
    MonthlyActiveUsersWorkerStore,
)
from synapse.storage.databases.main.presence import PresenceStore
+from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
@@ -237,7 +237,7 @@ class GenericWorkerSlavedStore(
    ClientIpWorkerStore,
    SlavedEventStore,
    SlavedKeyStore,
-    RoomStore,
+    RoomWorkerStore,
    DirectoryStore,
    SlavedApplicationServiceStore,
    SlavedRegistrationStore,
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index d6ec618f8f52..2cc242782add 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -237,13 +237,14 @@ def read_template(self, filename: str) -> jinja2.Template:
    def read_templates(
        self,
        filenames: List[str],
-        custom_template_directory: Optional[str] = None,
+        custom_template_directories: Optional[Iterable[str]] = None,
    ) -> List[jinja2.Template]:
        """Load a list of template files from disk using the given variables.

        This function will attempt to load the given templates from the default Synapse
-        template directory. If `custom_template_directory` is supplied, that directory
-        is tried first.
+        template directory. If `custom_template_directories` is supplied, each directory
+        in that list is tried (in the order given) before trying
+        Synapse's default directory.

        Files read are treated as Jinja templates. The templates are not rendered yet
        and have autoescape enabled.
@@ -251,8 +252,8 @@ def read_templates(
        Args:
            filenames: A list of template filenames to read.

-            custom_template_directory: A directory to try to look for the templates
-                before using the default Synapse template directory instead.
+            custom_template_directories: A list of directories in which to look for the
+                templates before using the default Synapse template directory instead.

        Raises:
            ConfigError: if the file's path is incorrect or otherwise cannot be read.
@@ -260,20 +261,26 @@ def read_templates(
        Returns:
            A list of jinja2 templates.
        """
-        search_directories = [self.default_template_dir]
-
-        # The loader will first look in the custom template directory (if specified) for the
-        # given filename. If it doesn't find it, it will use the default template dir instead
-        if custom_template_directory:
-            # Check that the given template directory exists
-            if not self.path_exists(custom_template_directory):
-                raise ConfigError(
-                    "Configured template directory does not exist: %s"
-                    % (custom_template_directory,)
-                )
+        search_directories = []
+
+        # The loader will first look in the custom template directories (if specified)
+        # for the given filename. If it doesn't find it, it will use the default
+        # template dir instead.
+ if custom_template_directories is not None: + for custom_template_directory in custom_template_directories: + # Check that the given template directory exists + if not self.path_exists(custom_template_directory): + raise ConfigError( + "Configured template directory does not exist: %s" + % (custom_template_directory,) + ) + + # Search the custom template directory as well + search_directories.append(custom_template_directory) - # Search the custom template directory as well - search_directories.insert(0, custom_template_directory) + # Append the default directory at the end of the list so Jinja can fallback on it + # if a template is missing from any custom directory. + search_directories.append(self.default_template_dir) # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(search_directories) diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index 6be4eafe5582..52e63ab1f6f2 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -78,6 +78,11 @@ def read_config(self, config, **kwargs): ) # Read and store template content + custom_template_directories = ( + self.root.server.custom_template_directory, + account_validity_template_dir, + ) + ( self.account_validity_account_renewed_template, self.account_validity_account_previously_renewed_template, @@ -88,5 +93,5 @@ def read_config(self, config, **kwargs): "account_previously_renewed.html", invalid_token_template_filename, ], - account_validity_template_dir, + (td for td in custom_template_directories if td), ) diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 8d5f38b5d934..d119427ad864 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -151,6 +151,15 @@ def generate_config_section(self, **kwargs): # entries are never evicted based on time. # #expiry_time: 30m + + # Controls how long the results of a /sync request are cached for after + # a successful response is returned. A higher duration can help clients with + # intermittent connections, at the cost of higher memory usage. + # + # By default, this is zero, which means that sync responses are not cached + # at all. 
+ # + #sync_response_cache_duration: 2m """ def read_config(self, config, **kwargs): @@ -212,6 +221,10 @@ def read_config(self, config, **kwargs): else: self.expiry_time_msec = None + self.sync_response_cache_duration = self.parse_duration( + cache_config.get("sync_response_cache_duration", 0) + ) + # Resize all caches (if necessary) with the new factors we've loaded self.resize_all_caches() diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 8d8f166e9bfb..4477419196c2 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -80,6 +80,12 @@ def read_config(self, config, **kwargs): self.require_transport_security = email_config.get( "require_transport_security", False ) + self.enable_smtp_tls = email_config.get("enable_tls", True) + if self.require_transport_security and not self.enable_smtp_tls: + raise ConfigError( + "email.require_transport_security requires email.enable_tls to be true" + ) + if "app_name" in email_config: self.email_app_name = email_config["app_name"] else: @@ -251,7 +257,14 @@ def read_config(self, config, **kwargs): registration_template_success_html, add_threepid_template_success_html, ], - template_dir, + ( + td + for td in ( + self.root.server.custom_template_directory, + template_dir, + ) + if td + ), # Filter out template_dir if not provided ) # Render templates that do not contain any placeholders @@ -291,7 +304,14 @@ def read_config(self, config, **kwargs): self.email_notif_template_text, ) = self.read_templates( [notif_template_html, notif_template_text], - template_dir, + ( + td + for td in ( + self.root.server.custom_template_directory, + template_dir, + ) + if td + ), # Filter out template_dir if not provided ) self.email_notif_for_new_users = email_config.get( @@ -314,7 +334,14 @@ def read_config(self, config, **kwargs): self.account_validity_template_text, ) = self.read_templates( [expiry_template_html, expiry_template_text], - template_dir, + ( + td + for td in ( + self.root.server.custom_template_directory, + template_dir, + ) + if td + ), # Filter out template_dir if not provided ) subjects_config = email_config.get("subjects", {}) @@ -346,6 +373,9 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): """\ # Configuration for sending emails from Synapse. # + # Server admins can configure custom templates for email content. See + # https://matrix-org.github.io/synapse/latest/templates.html for more information. + # email: # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'. # @@ -368,6 +398,14 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): # #require_transport_security: true + # Uncomment the following to disable TLS for SMTP. + # + # By default, if the server supports TLS, it will be used, and the server + # must present a certificate that is valid for 'smtp_host'. If this option + # is set to false, TLS will not be used. + # + #enable_tls: false + # notif_from defines the "From" address to use when sending emails. # It must be set if email sending is enabled. # @@ -414,49 +452,6 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): # #invite_client_location: https://app.element.io - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. 
- # - # Synapse will look for the following templates in this directory: - # - # * The contents of email notifications of missed events: 'notif_mail.html' and - # 'notif_mail.txt'. - # - # * The contents of account expiry notice emails: 'notice_expiry.html' and - # 'notice_expiry.txt'. - # - # * The contents of password reset emails sent by the homeserver: - # 'password_reset.html' and 'password_reset.txt' - # - # * An HTML page that a user will see when they follow the link in the password - # reset email. The user will be asked to confirm the action before their - # password is reset: 'password_reset_confirmation.html' - # - # * HTML pages for success and failure that a user will see when they confirm - # the password reset flow using the page above: 'password_reset_success.html' - # and 'password_reset_failure.html' - # - # * The contents of address verification emails sent during registration: - # 'registration.html' and 'registration.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent during registration: - # 'registration_success.html' and 'registration_failure.html' - # - # * The contents of address verification emails sent when an address is added - # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent when an address is added - # to a Matrix account: 'add_threepid_success.html' and - # 'add_threepid_failure.html' - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" - # Subjects to use when sending emails from Synapse. # # The placeholder '%%(app)s' will be replaced with the value of the 'app_name' diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 4c60ee8c2859..907df9591a85 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -37,4 +37,7 @@ def read_config(self, config: JsonDict, **kwargs): self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) # MSC3244 (room version capabilities) - self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False) + self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) + + # MSC3266 (room summary api) + self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index ad4e6e61c3bf..4a398a7932cc 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -67,18 +67,31 @@ backupCount: 3 # Does not include the current log file. encoding: utf8 - # Default to buffering writes to log file for efficiency. This means that - # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR - # logs will still be flushed immediately. + # Default to buffering writes to log file for efficiency. + # WARNING/ERROR logs will still be flushed immediately, but there will be a + # delay (of up to `period` seconds, or until the buffer is full with + # `capacity` messages) before INFO/DEBUG logs get written. buffer: - class: logging.handlers.MemoryHandler + class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler target: file - # The capacity is the number of log lines that are buffered before - # being written to disk. 
Increasing this will lead to better + + # The capacity is the maximum number of log lines that are buffered + # before being written to disk. Increasing this will lead to better # performance, at the expensive of it taking longer for log lines to # be written to disk. + # This parameter is required. capacity: 10 - flushLevel: 30 # Flush for WARNING logs as well + + # Logs with a level at or above the flush level will cause the buffer to + # be flushed immediately. + # Default value: 40 (ERROR) + # Other values: 50 (CRITICAL), 30 (WARNING), 20 (INFO), 10 (DEBUG) + flushLevel: 30 # Flush immediately for WARNING logs and higher + + # The period of time, in seconds, between forced flushes. + # Messages will not be delayed for longer than this time. + # Default value: 5 seconds + period: 5 # A handler that writes logs to stderr. Unused by default, but can be used # instead of "buffer" and "file" in the logger handlers. diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 0dfb3a227a3b..7481f3bf5f0f 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from collections import namedtuple from typing import Dict, List +from urllib.request import getproxies_environment # type: ignore from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements @@ -22,6 +24,8 @@ from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + DEFAULT_THUMBNAIL_SIZES = [ {"width": 32, "height": 32, "method": "crop"}, {"width": 96, "height": 96, "method": "crop"}, @@ -36,6 +40,9 @@ # method: %(method)s """ +HTTP_PROXY_SET_WARNING = """\ +The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" + ThumbnailRequirement = namedtuple( "ThumbnailRequirement", ["width", "height", "method", "media_type"] ) @@ -180,12 +187,17 @@ def read_config(self, config, **kwargs): e.message # noqa: B306, DependencyException.message is a property ) + proxy_env = getproxies_environment() if "url_preview_ip_range_blacklist" not in config: - raise ConfigError( - "For security, you must specify an explicit target IP address " - "blacklist in url_preview_ip_range_blacklist for url previewing " - "to work" - ) + if "http" not in proxy_env or "https" not in proxy_env: + raise ConfigError( + "For security, you must specify an explicit target IP address " + "blacklist in url_preview_ip_range_blacklist for url previewing " + "to work" + ) + else: + if "http" in proxy_env or "https" in proxy_env: + logger.warning("".join(HTTP_PROXY_SET_WARNING)) # we always blacklist '0.0.0.0' and '::', which are supposed to be # unroutable addresses. @@ -292,6 +304,8 @@ def generate_config_section(self, data_dir_path, **kwargs): # This must be specified if url_preview_enabled is set. It is recommended that # you uncomment the following list as a starting point. # + # Note: The value is ignored when an HTTP proxy is in use + # #url_preview_ip_range_blacklist: %(ip_range_blacklist)s diff --git a/synapse/config/server.py b/synapse/config/server.py index b9e0c0b30093..849479591971 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -710,6 +710,18 @@ class LimitRemoteRoomsConfig: # Turn the list into a set to improve lookup speed. 
self.next_link_domain_whitelist = set(next_link_domain_whitelist) + templates_config = config.get("templates") or {} + if not isinstance(templates_config, dict): + raise ConfigError("The 'templates' section must be a dictionary") + + self.custom_template_directory = templates_config.get( + "custom_template_directory" + ) + if self.custom_template_directory is not None and not isinstance( + self.custom_template_directory, str + ): + raise ConfigError("'custom_template_directory' must be a string") + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) @@ -960,6 +972,8 @@ def generate_config_section( # # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. # + # Note: The value is ignored when an HTTP proxy is in use + # #ip_range_blacklist: %(ip_range_blacklist)s @@ -1282,6 +1296,19 @@ def generate_config_section( # all domains. # #next_link_domain_whitelist: ["matrix.org"] + + # Templates to use when generating email or HTML page contents. + # + templates: + # Directory in which Synapse will try to find template files to use to generate + # email or HTML page contents. + # If not set, or a file is not found within the template directory, a default + # template from within the Synapse package will be used. + # + # See https://matrix-org.github.io/synapse/latest/templates.html for more + # information about using custom templates. + # + #custom_template_directory: /path/to/custom/templates/ """ % locals() ) diff --git a/synapse/config/sso.py b/synapse/config/sso.py index d0f04cf8e6b2..fe1177ab8109 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -45,6 +45,11 @@ def read_config(self, config, **kwargs): self.sso_template_dir = sso_config.get("template_dir") # Read templates from disk + custom_template_directories = ( + self.root.server.custom_template_directory, + self.sso_template_dir, + ) + ( self.sso_login_idp_picker_template, self.sso_redirect_confirm_template, @@ -63,7 +68,7 @@ def read_config(self, config, **kwargs): "sso_auth_success.html", "sso_auth_bad_user.html", ], - self.sso_template_dir, + (td for td in custom_template_directories if td), ) # These templates have no placeholders, so render them here @@ -94,6 +99,9 @@ def generate_config_section(self, **kwargs): # Additional settings to use with single-sign on systems such as OpenID Connect, # SAML2 and CAS. # + # Server admins can configure custom templates for pages related to SSO. See + # https://matrix-org.github.io/synapse/latest/templates.html for more information. + # sso: # A list of client URLs which are whitelisted so that the user does not # have to confirm giving access to their account to the URL. Any client @@ -125,167 +133,4 @@ def generate_config_section(self, **kwargs): # information when first signing in. Defaults to false. # #update_profile_information: true - - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. - # - # Synapse will look for the following templates in this directory: - # - # * HTML page to prompt the user to choose an Identity Provider during - # login: 'sso_login_idp_picker.html'. - # - # This is only used if multiple SSO Identity Providers are configured. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL that the user will be redirected to after - # login. - # - # * server_name: the homeserver's name. 
- # - # * providers: a list of available Identity Providers. Each element is - # an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # The rendered HTML page should contain a form which submits its results - # back as a GET request, with the following query parameters: - # - # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed - # to the template) - # - # * idp: the 'idp_id' of the chosen IDP. - # - # * HTML page to prompt new users to enter a userid and confirm other - # details: 'sso_auth_account_details.html'. This is only shown if the - # SSO implementation (with any user_mapping_provider) does not return - # a localpart. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. - # - # * idp: details of the SSO Identity Provider that the user logged in - # with: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * user_attributes: an object containing details about the user that - # we received from the IdP. May have the following attributes: - # - # * display_name: the user's display_name - # * emails: a list of email addresses - # - # The template should render a form which submits the following fields: - # - # * username: the localpart of the user's chosen user id - # - # * HTML page allowing the user to consent to the server's terms and - # conditions. This is only shown for new users, and only if - # `user_consent.require_at_registration` is set. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. - # - # * user_id: the user's matrix proposed ID. - # - # * user_profile.display_name: the user's proposed display name, if any. - # - # * consent_version: the version of the terms that the user will be - # shown - # - # * terms_url: a link to the page showing the terms. - # - # The template should render a form which submits the following fields: - # - # * accepted_version: the version of the terms accepted by the user - # (ie, 'consent_version' from the input variables). - # - # * HTML page for a confirmation step before redirecting back to the client - # with the login token: 'sso_redirect_confirm.html'. - # - # When rendering, this template is given the following variables: - # - # * redirect_url: the URL the user is about to be redirected to. - # - # * display_url: the same as `redirect_url`, but with the query - # parameters stripped. The intention is to have a - # human-readable URL to show to users, not to use it as - # the final address to redirect to. - # - # * server_name: the homeserver's name. - # - # * new_user: a boolean indicating whether this is the user's first time - # logging in. - # - # * user_id: the user's matrix ID. - # - # * user_profile.avatar_url: an MXC URI for the user's avatar, if any. - # None if the user has not set an avatar. - # - # * user_profile.display_name: the user's display name. None if the user - # has not set a display name. 
- # - # * HTML page which notifies the user that they are authenticating to confirm - # an operation on their account during the user interactive authentication - # process: 'sso_auth_confirm.html'. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL the user is about to be redirected to. - # - # * description: the operation which the user is being asked to confirm - # - # * idp: details of the Identity Provider that we will use to confirm - # the user's identity: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * HTML page shown after a successful user interactive authentication session: - # 'sso_auth_success.html'. - # - # Note that this page must include the JavaScript which notifies of a successful authentication - # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). - # - # This template has no additional variables. - # - # * HTML page shown after a user-interactive authentication session which - # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. - # - # When rendering, this template is given the following variables: - # * server_name: the homeserver's name. - # * user_id_to_verify: the MXID of the user that we are trying to - # validate. - # - # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) - # attempts to login: 'sso_account_deactivated.html'. - # - # This template has no additional variables. - # - # * HTML page to display to users if something goes wrong during the - # OpenID Connect authentication process: 'sso_error.html'. 
-        #
-        # When rendering, this template is given two variables:
-        #     * error: the technical name of the error
-        #     * error_description: a human-readable message for the error
-        #
-        # You can see the default templates at:
-        # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
-        #
-        #template_dir: "res/templates"
         """
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 0298af4c02d7..a730c1719a95 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -396,10 +396,11 @@ def __str__(self):
         return self.__repr__()
 
     def __repr__(self):
-        return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % (
+        return "<FrozenEvent event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
             self.get("event_id", None),
             self.get("type", None),
             self.get("state_key", None),
+            self.internal_metadata.is_outlier(),
         )
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index a0c07f62f44b..b6da2f60af99 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -17,7 +17,7 @@
 
 from frozendict import frozendict
 
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.util.async_helpers import yieldable_gather_results
@@ -135,6 +135,12 @@ def add_fields(*fields):
         add_fields("history_visibility")
     elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
         add_fields("redacts")
+    elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION:
+        add_fields(EventContentFields.MSC2716_NEXT_CHUNK_ID)
+    elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_CHUNK:
+        add_fields(EventContentFields.MSC2716_CHUNK_ID)
+    elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
+        add_fields(EventContentFields.MSC2716_MARKER_INSERTION)
 
     allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b7a10da15a89..29979414e3d7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1108,7 +1108,8 @@ async def get_public_rooms(
             The response from the remote server.
 
         Raises:
-            HttpResponseException: There was an exception returned from the remote server
+            HttpResponseException / RequestSendFailed: There was an exception
+                returned from the remote server
 
             SynapseException: M_FORBIDDEN when the remote server has disallowed
                 publicRoom requests over federation
@@ -1289,8 +1290,136 @@ async def send_request(destination: str) -> FederationSpaceSummaryResult:
             failover_on_unknown_endpoint=True,
         )
 
+    async def get_room_hierarchy(
+        self,
+        destinations: Iterable[str],
+        room_id: str,
+        suggested_only: bool,
+    ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
+        """
+        Call other servers to get a hierarchy of the given room.
+
+        Performs simple data validation and parsing of the response.
-@attr.s(frozen=True, slots=True)
+        Args:
+            destinations: The remote servers. We will try them in turn, omitting any
+                that have been blacklisted.
+            room_id: ID of the space to be queried
+            suggested_only: If true, ask the remote server to only return children
+                with the "suggested" flag set
+
+        Returns:
+            A tuple of:
+                The room as a JSON dictionary.
+                A list of children rooms, as JSON dictionaries.
+                A list of inaccessible children room IDs.
+ + Raises: + SynapseError if we were unable to get a valid summary from any of the + remote servers + """ + + async def send_request( + destination: str, + ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: + res = await self.transport_layer.get_room_hierarchy( + destination=destination, + room_id=room_id, + suggested_only=suggested_only, + ) + + room = res.get("room") + if not isinstance(room, dict): + raise InvalidResponseError("'room' must be a dict") + + # Validate children_state of the room. + children_state = room.get("children_state", []) + if not isinstance(children_state, Sequence): + raise InvalidResponseError("'room.children_state' must be a list") + if any(not isinstance(e, dict) for e in children_state): + raise InvalidResponseError("Invalid event in 'children_state' list") + try: + [ + FederationSpaceSummaryEventResult.from_json_dict(e) + for e in children_state + ] + except ValueError as e: + raise InvalidResponseError(str(e)) + + # Validate the children rooms. + children = res.get("children", []) + if not isinstance(children, Sequence): + raise InvalidResponseError("'children' must be a list") + if any(not isinstance(r, dict) for r in children): + raise InvalidResponseError("Invalid room in 'children' list") + + # Validate the inaccessible children. + inaccessible_children = res.get("inaccessible_children", []) + if not isinstance(inaccessible_children, Sequence): + raise InvalidResponseError("'inaccessible_children' must be a list") + if any(not isinstance(r, str) for r in inaccessible_children): + raise InvalidResponseError( + "Invalid room ID in 'inaccessible_children' list" + ) + + return room, children, inaccessible_children + + try: + return await self._try_destination_list( + "fetch room hierarchy", + destinations, + send_request, + failover_on_unknown_endpoint=True, + ) + except SynapseError as e: + # Fallback to the old federation API and translate the results if + # no servers implement the new API. + # + # The algorithm below is a bit inefficient as it only attempts to + # get information for the requested room, but the legacy API may + # return additional layers. + if e.code == 502: + legacy_result = await self.get_space_summary( + destinations, + room_id, + suggested_only, + max_rooms_per_space=None, + exclude_rooms=[], + ) + + # Find the requested room in the response (and remove it). + for _i, room in enumerate(legacy_result.rooms): + if room.get("room_id") == room_id: + break + else: + # The requested room was not returned, nothing we can do. + raise + requested_room = legacy_result.rooms.pop(_i) + + # Find any children events of the requested room. + children_events = [] + children_room_ids = set() + for event in legacy_result.events: + if event.room_id == room_id: + children_events.append(event.data) + children_room_ids.add(event.state_key) + # And add them under the requested room. + requested_room["children_state"] = children_events + + # Find the children rooms. + children = [] + for room in legacy_result.rooms: + if room.get("room_id") in children_room_ids: + children.append(room) + + # It isn't clear from the response whether some of the rooms are + # not accessible. + return requested_room, children, () + + raise + + +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryEventResult: """Represents a single event in the result of a successful get_space_summary call. @@ -1299,12 +1428,13 @@ class FederationSpaceSummaryEventResult: object attributes. 
""" - event_type = attr.ib(type=str) - state_key = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) + event_type: str + room_id: str + state_key: str + via: Sequence[str] # the raw data, including the above keys - data = attr.ib(type=JsonDict) + data: JsonDict @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": @@ -1321,6 +1451,10 @@ def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": if not isinstance(event_type, str): raise ValueError("Invalid event: 'event_type' must be a str") + room_id = d.get("room_id") + if not isinstance(room_id, str): + raise ValueError("Invalid event: 'room_id' must be a str") + state_key = d.get("state_key") if not isinstance(state_key, str): raise ValueError("Invalid event: 'state_key' must be a str") @@ -1335,15 +1469,15 @@ def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": if any(not isinstance(v, str) for v in via): raise ValueError("Invalid event: 'via' must be a list of strings") - return cls(event_type, state_key, via, d) + return cls(event_type, room_id, state_key, via, d) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryResult: """Represents the data returned by a successful get_space_summary call.""" - rooms = attr.ib(type=Sequence[JsonDict]) - events = attr.ib(type=Sequence[FederationSpaceSummaryEventResult]) + rooms: List[JsonDict] + events: Sequence[FederationSpaceSummaryEventResult] @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": @@ -1356,7 +1490,7 @@ def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": ValueError if d is not a valid /spaces/ response """ rooms = d.get("rooms") - if not isinstance(rooms, Sequence): + if not isinstance(rooms, List): raise ValueError("'rooms' must be a list") if any(not isinstance(r, dict) for r in rooms): raise ValueError("Invalid room in 'rooms' list") diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 145b9161d985..afd8f8580a2f 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -195,13 +195,17 @@ async def on_backfill_request( origin, room_id, versions, limit ) - res = self._transaction_from_pdus(pdus).get_dict() + res = self._transaction_dict_from_pdus(pdus) return 200, res async def on_incoming_transaction( - self, origin: str, transaction_data: JsonDict - ) -> Tuple[int, Dict[str, Any]]: + self, + origin: str, + transaction_id: str, + destination: str, + transaction_data: JsonDict, + ) -> Tuple[int, JsonDict]: # If we receive a transaction we should make sure that kick off handling # any old events in the staging area. if not self._started_handling_of_staged_events: @@ -212,8 +216,14 @@ async def on_incoming_transaction( # accurate as possible. 
request_time = self._clock.time_msec() - transaction = Transaction(**transaction_data) - transaction_id = transaction.transaction_id # type: ignore + transaction = Transaction( + transaction_id=transaction_id, + destination=destination, + origin=origin, + origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore + pdus=transaction_data.get("pdus"), # type: ignore + edus=transaction_data.get("edus"), + ) if not transaction_id: raise Exception("Transaction missing transaction_id") @@ -221,9 +231,7 @@ async def on_incoming_transaction( logger.debug("[%s] Got transaction", transaction_id) # Reject malformed transactions early: reject if too many PDUs/EDUs - if len(transaction.pdus) > 50 or ( # type: ignore - hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore - ): + if len(transaction.pdus) > 50 or len(transaction.edus) > 100: logger.info("Transaction PDU or EDU count too large. Returning 400") return 400, {} @@ -263,7 +271,7 @@ async def _on_incoming_transaction_inner( # CRITICAL SECTION: the first thing we must do (before awaiting) is # add an entry to _active_transactions. assert origin not in self._active_transactions - self._active_transactions[origin] = transaction.transaction_id # type: ignore + self._active_transactions[origin] = transaction.transaction_id try: result = await self._handle_incoming_transaction( @@ -291,11 +299,11 @@ async def _handle_incoming_transaction( if response: logger.debug( "[%s] We've already responded to this request", - transaction.transaction_id, # type: ignore + transaction.transaction_id, ) return response - logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore + logger.debug("[%s] Transaction is new", transaction.transaction_id) # We process PDUs and EDUs in parallel. This is important as we don't # want to block things like to device messages from reaching clients @@ -334,7 +342,7 @@ async def _handle_pdus_in_txn( report back to the sending server. 
""" - received_pdus_counter.inc(len(transaction.pdus)) # type: ignore + received_pdus_counter.inc(len(transaction.pdus)) origin_host, _ = parse_server_name(origin) @@ -342,7 +350,7 @@ async def _handle_pdus_in_txn( newest_pdu_ts = 0 - for p in transaction.pdus: # type: ignore + for p in transaction.pdus: # FIXME (richardv): I don't think this works: # https://github.com/matrix-org/synapse/issues/8429 if "unsigned" in p: @@ -436,10 +444,10 @@ async def process_pdu(pdu: EventBase) -> JsonDict: return pdu_results - async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): + async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None: """Process the EDUs in a received transaction.""" - async def _process_edu(edu_dict): + async def _process_edu(edu_dict: JsonDict) -> None: received_edus_counter.inc() edu = Edu( @@ -452,7 +460,7 @@ async def _process_edu(edu_dict): await concurrently_execute( _process_edu, - getattr(transaction, "edus", []), + transaction.edus, TRANSACTION_CONCURRENCY_LIMIT, ) @@ -538,7 +546,7 @@ async def on_pdu_request( pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: - return 200, self._transaction_from_pdus([pdu]).get_dict() + return 200, self._transaction_dict_from_pdus([pdu]) else: return 404, "" @@ -879,18 +887,20 @@ async def on_openid_userinfo(self, token: str) -> Optional[str]: ts_now_ms = self._clock.time_msec() return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) - def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction: + def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict: """Returns a new Transaction containing the given PDUs suitable for transmission. """ time_now = self._clock.time_msec() pdus = [p.get_pdu_json(time_now) for p in pdu_list] return Transaction( + # Just need a dummy transaction ID and destination since it won't be used. + transaction_id="", origin=self.server_name, pdus=pdus, origin_server_ts=int(time_now), - destination=None, - ) + destination="", + ).get_dict() async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: """Process a PDU received in a federation /send/ transaction. @@ -962,13 +972,18 @@ async def _process_incoming_pdus_in_room_inner( # the room, so instead of pulling the event out of the DB and parsing # the event we just pull out the next event ID and check if that matches. if latest_event is not None and latest_origin is not None: - ( - next_origin, - next_event_id, - ) = await self.store.get_next_staged_event_id_for_room(room_id) - if next_origin != latest_origin or next_event_id != latest_event.event_id: + result = await self.store.get_next_staged_event_id_for_room(room_id) + if result is None: latest_origin = None latest_event = None + else: + next_origin, next_event_id = result + if ( + next_origin != latest_origin + or next_event_id != latest_event.event_id + ): + latest_origin = None + latest_event = None if latest_origin is None or latest_event is None: next = await self.store.get_next_staged_event_for_room( @@ -988,6 +1003,7 @@ async def _process_incoming_pdus_in_room_inner( # has started processing). 
while True: async with lock: + logger.info("handling received PDU: %s", event) try: await self.handler.on_receive_pdu( origin, event, sent_to_us_directly=True diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 2f9c9bc2cdc8..4fead6ca2954 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -45,7 +45,7 @@ async def have_responded( `None` if we have not previously responded to this transaction or a 2-tuple of `(int, dict)` representing the response code and response body. """ - transaction_id = transaction.transaction_id # type: ignore + transaction_id = transaction.transaction_id if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") @@ -56,7 +56,7 @@ async def set_response( self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: """Persist how we responded to a transaction.""" - transaction_id = transaction.transaction_id # type: ignore + transaction_id = transaction.transaction_id if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 72a635830b9a..dc555cca0bbf 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -27,6 +27,7 @@ tags, whitelisted_homeserver, ) +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.metrics import measure_func @@ -104,13 +105,13 @@ async def send_new_transaction( len(edus), ) - transaction = Transaction.create_new( + transaction = Transaction( origin_server_ts=int(self.clock.time_msec()), transaction_id=txn_id, origin=self._server_name, destination=destination, - pdus=pdus, - edus=edus, + pdus=[p.get_pdu_json() for p in pdus], + edus=[edu.get_dict() for edu in edus], ) self._next_txn_id += 1 @@ -131,7 +132,7 @@ async def send_new_transaction( # FIXME (richardv): I also believe it no longer works. We (now?) store # "age_ts" in "unsigned" rather than at the top level. See # https://github.com/matrix-org/synapse/issues/8429. - def json_data_cb(): + def json_data_cb() -> JsonDict: data = transaction.get_dict() now = int(self.clock.time_msec()) if "pdus" in data: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 6a8d3ad4fe6d..8b247fe2066d 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -143,7 +143,7 @@ async def send_transaction( """Sends the given Transaction to its destination Args: - transaction (Transaction) + transaction Returns: Succeeds when we get a 2xx HTTP response. The result @@ -1177,6 +1177,28 @@ async def get_space_summary( destination=destination, path=path, data=params ) + async def get_room_hierarchy( + self, + destination: str, + room_id: str, + suggested_only: bool, + ) -> JsonDict: + """ + Args: + destination: The remote server + room_id: The room ID to ask about. 
+ suggested_only: if True, only suggested rooms will be returned + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc2946/hierarchy/%s", room_id + ) + + return await self.client.get_json( + destination=destination, + path=path, + args={"suggested_only": "true" if suggested_only else "false"}, + ) + def _create_path(federation_prefix: str, path: str, *args: str) -> str: """ diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py deleted file mode 100644 index 5e059d6e09d4..000000000000 --- a/synapse/federation/transport/server.py +++ /dev/null @@ -1,2139 +0,0 @@ -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. -# Copyright 2020 Sorunome -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import functools -import logging -import re -from typing import ( - Container, - Dict, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - Union, -) - -from typing_extensions import Literal - -import synapse -from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH -from synapse.api.errors import Codes, FederationDeniedError, SynapseError -from synapse.api.room_versions import RoomVersions -from synapse.api.urls import ( - FEDERATION_UNSTABLE_PREFIX, - FEDERATION_V1_PREFIX, - FEDERATION_V2_PREFIX, -) -from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.http.server import HttpServer, JsonResource -from synapse.http.servlet import ( - parse_boolean_from_args, - parse_integer_from_args, - parse_json_object_from_request, - parse_string_from_args, - parse_strings_from_args, -) -from synapse.logging import opentracing -from synapse.logging.context import run_in_background -from synapse.logging.opentracing import ( - SynapseTags, - start_active_span, - start_active_span_from_request, - tags, - whitelisted_homeserver, -) -from synapse.server import HomeServer -from synapse.types import JsonDict, ThirdPartyInstanceID, get_domain_from_id -from synapse.util.ratelimitutils import FederationRateLimiter -from synapse.util.stringutils import parse_and_validate_server_name -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger(__name__) - - -class TransportLayerServer(JsonResource): - """Handles incoming federation HTTP requests""" - - def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None): - """Initialize the TransportLayerServer - - Will by default register all servlets. For custom behaviour, pass in - a list of servlet_groups to register. - - Args: - hs: homeserver - servlet_groups: List of servlet groups to register. - Defaults to ``DEFAULT_SERVLET_GROUPS``. 
- """ - self.hs = hs - self.clock = hs.get_clock() - self.servlet_groups = servlet_groups - - super().__init__(hs, canonical_json=False) - - self.authenticator = Authenticator(hs) - self.ratelimiter = hs.get_federation_ratelimiter() - - self.register_servlets() - - def register_servlets(self) -> None: - register_servlets( - self.hs, - resource=self, - ratelimiter=self.ratelimiter, - authenticator=self.authenticator, - servlet_groups=self.servlet_groups, - ) - - -class AuthenticationError(SynapseError): - """There was a problem authenticating the request""" - - -class NoAuthenticationError(AuthenticationError): - """The request had no authentication information""" - - -class Authenticator: - def __init__(self, hs: HomeServer): - self._clock = hs.get_clock() - self.keyring = hs.get_keyring() - self.server_name = hs.hostname - self.store = hs.get_datastore() - self.federation_domain_whitelist = hs.config.federation_domain_whitelist - self.notifier = hs.get_notifier() - - self.replication_client = None - if hs.config.worker.worker_app: - self.replication_client = hs.get_tcp_replication() - - # A method just so we can pass 'self' as the authenticator to the Servlets - async def authenticate_request(self, request, content): - now = self._clock.time_msec() - json_request = { - "method": request.method.decode("ascii"), - "uri": request.uri.decode("ascii"), - "destination": self.server_name, - "signatures": {}, - } - - if content is not None: - json_request["content"] = content - - origin = None - - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") - - if not auth_headers: - raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED - ) - - for auth in auth_headers: - if auth.startswith(b"X-Matrix"): - (origin, key, sig) = _parse_auth_header(auth) - json_request["origin"] = origin - json_request["signatures"].setdefault(origin, {})[key] = sig - - if ( - self.federation_domain_whitelist is not None - and origin not in self.federation_domain_whitelist - ): - raise FederationDeniedError(origin) - - if origin is None or not json_request["signatures"]: - raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED - ) - - await self.keyring.verify_json_for_server( - origin, - json_request, - now, - ) - - logger.debug("Request from %s", origin) - request.requester = origin - - # If we get a valid signed request from the other side, its probably - # alive - retry_timings = await self.store.get_destination_retry_timings(origin) - if retry_timings and retry_timings.retry_last_ts: - run_in_background(self._reset_retry_timings, origin) - - return origin - - async def _reset_retry_timings(self, origin): - try: - logger.info("Marking origin %r as up", origin) - await self.store.set_destination_retry_timings(origin, None, 0, 0) - - # Inform the relevant places that the remote server is back up. - self.notifier.notify_remote_server_up(origin) - if self.replication_client: - # If we're on a worker we try and inform master about this. The - # replication client doesn't hook into the notifier to avoid - # infinite loops where we send a `REMOTE_SERVER_UP` command to - # master, which then echoes it back to us which in turn pokes - # the notifier. 
-                self.replication_client.send_remote_server_up(origin)
-
-        except Exception:
-            logger.exception("Error resetting retry timings on %s", origin)
-
-
-def _parse_auth_header(header_bytes):
-    """Parse an X-Matrix auth header
-
-    Args:
-        header_bytes (bytes): header value
-
-    Returns:
-        Tuple[str, str, str]: origin, key id, signature.
-
-    Raises:
-        AuthenticationError if the header could not be parsed
-    """
-    try:
-        header_str = header_bytes.decode("utf-8")
-        params = header_str.split(" ")[1].split(",")
-        param_dict = dict(kv.split("=") for kv in params)
-
-        def strip_quotes(value):
-            if value.startswith('"'):
-                return value[1:-1]
-            else:
-                return value
-
-        origin = strip_quotes(param_dict["origin"])
-
-        # ensure that the origin is a valid server name
-        parse_and_validate_server_name(origin)
-
-        key = strip_quotes(param_dict["key"])
-        sig = strip_quotes(param_dict["sig"])
-        return origin, key, sig
-    except Exception as e:
-        logger.warning(
-            "Error parsing auth header '%s': %s",
-            header_bytes.decode("ascii", "replace"),
-            e,
-        )
-        raise AuthenticationError(
-            400, "Malformed Authorization header", Codes.UNAUTHORIZED
-        )
-
-
-class BaseFederationServlet:
-    """Abstract base class for federation servlet classes.
-
-    The servlet object should have a PATH attribute which takes the form of a regexp to
-    match against the request path (excluding the /federation/v1 prefix).
-
-    The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
-    the appropriate HTTP method. These methods must be *asynchronous* and have the
-    signature:
-
-        on_<METHOD>(self, origin, content, query, **kwargs)
-
-        With arguments:
-
-            origin (unicode|None): The authenticated server_name of the calling server,
-                unless REQUIRE_AUTH is set to False and authentication failed.
-
-            content (unicode|None): decoded json body of the request. None if the
-                request was a GET.
-
-            query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
-                (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
-                yet.
-
-            **kwargs (dict[unicode, unicode]): the dict mapping keys to path
-                components as specified in the path match regexp.
-
-        Returns:
-            Optional[Tuple[int, object]]: either (response code, response object) to
-                return a JSON response, or None if the request has already been handled.
-
-        Raises:
-            SynapseError: to return an error code
-
-            Exception: other exceptions will be caught, logged, and a 500 will be
-                returned.
-    """
-
-    PATH = ""  # Overridden in subclasses, the regex to match against the path.
-
-    REQUIRE_AUTH = True
-
-    PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version
-
-    RATELIMIT = True  # Whether to rate limit requests or not
-
-    def __init__(
-        self,
-        hs: HomeServer,
-        authenticator: Authenticator,
-        ratelimiter: FederationRateLimiter,
-        server_name: str,
-    ):
-        self.hs = hs
-        self.authenticator = authenticator
-        self.ratelimiter = ratelimiter
-        self.server_name = server_name
-
-    def _wrap(self, func):
-        authenticator = self.authenticator
-        ratelimiter = self.ratelimiter
-
-        @functools.wraps(func)
-        async def new_func(request, *args, **kwargs):
-            """A callback which can be passed to HttpServer.RegisterPaths
-
-            Args:
-                request (twisted.web.http.Request):
-                *args: unused?
-                **kwargs (dict[unicode, unicode]): the dict mapping keys to path
-                    components as specified in the path match regexp.
-
-            Returns:
-                Tuple[int, object]|None: (response code, response object) as returned by
-                    the callback method. None if the request has already been handled.
- """ - content = None - if request.method in [b"PUT", b"POST"]: - # TODO: Handle other method types? other content types? - content = parse_json_object_from_request(request) - - try: - origin = await authenticator.authenticate_request(request, content) - except NoAuthenticationError: - origin = None - if self.REQUIRE_AUTH: - logger.warning( - "authenticate_request failed: missing authentication" - ) - raise - except Exception as e: - logger.warning("authenticate_request failed: %s", e) - raise - - request_tags = { - SynapseTags.REQUEST_ID: request.get_request_id(), - tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, - tags.HTTP_METHOD: request.get_method(), - tags.HTTP_URL: request.get_redacted_uri(), - tags.PEER_HOST_IPV6: request.getClientIP(), - "authenticated_entity": origin, - "servlet_name": request.request_metrics.name, - } - - # Only accept the span context if the origin is authenticated - # and whitelisted - if origin and whitelisted_homeserver(origin): - scope = start_active_span_from_request( - request, "incoming-federation-request", tags=request_tags - ) - else: - scope = start_active_span( - "incoming-federation-request", tags=request_tags - ) - - with scope: - opentracing.inject_response_headers(request.responseHeaders) - - if origin and self.RATELIMIT: - with ratelimiter.ratelimit(origin) as d: - await d - if request._disconnected: - logger.warning( - "client disconnected before we started processing " - "request" - ) - return -1, None - response = await func( - origin, content, request.args, *args, **kwargs - ) - else: - response = await func( - origin, content, request.args, *args, **kwargs - ) - - return response - - return new_func - - def register(self, server): - pattern = re.compile("^" + self.PREFIX + self.PATH + "$") - - for method in ("GET", "PUT", "POST"): - code = getattr(self, "on_%s" % (method), None) - if code is None: - continue - - server.register_paths( - method, - (pattern,), - self._wrap(code), - self.__class__.__name__, - ) - - -class BaseFederationServerServlet(BaseFederationServlet): - """Abstract base class for federation servlet classes which provides a federation server handler. - - See BaseFederationServlet for more information. - """ - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_federation_server() - - -class FederationSendServlet(BaseFederationServerServlet): - PATH = "/send/(?P[^/]*)/?" - - # We ratelimit manually in the handler as we queue up the requests and we - # don't want to fill up the ratelimiter with blocked requests. - RATELIMIT = False - - # This is when someone is trying to send us a bunch of data. - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - transaction_id: str, - ) -> Tuple[int, JsonDict]: - """Called on PUT /send// - - Args: - transaction_id: The transaction_id associated with this request. This - is *not* None. - - Returns: - Tuple of `(code, response)`, where - `response` is a python dict to be converted into JSON that is - used as the response body. - """ - # Parse the request - try: - transaction_data = content - - logger.debug("Decoded %s: %s", transaction_id, str(transaction_data)) - - logger.info( - "Received txn %s from %s. 
(PDUs: %d, EDUs: %d)", - transaction_id, - origin, - len(transaction_data.get("pdus", [])), - len(transaction_data.get("edus", [])), - ) - - # We should ideally be getting this from the security layer. - # origin = body["origin"] - - # Add some extra data to the transaction dict that isn't included - # in the request body. - transaction_data.update( - transaction_id=transaction_id, destination=self.server_name - ) - - except Exception as e: - logger.exception(e) - return 400, {"error": "Invalid transaction"} - - code, response = await self.handler.on_incoming_transaction( - origin, transaction_data - ) - - return code, response - - -class FederationEventServlet(BaseFederationServerServlet): - PATH = "/event/(?P[^/]*)/?" - - # This is when someone asks for a data item for a given server data_id pair. - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - event_id: str, - ) -> Tuple[int, Union[JsonDict, str]]: - return await self.handler.on_pdu_request(origin, event_id) - - -class FederationStateV1Servlet(BaseFederationServerServlet): - PATH = "/state/(?P[^/]*)/?" - - # This is when someone asks for all data for a given room. - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - return await self.handler.on_room_state_request( - origin, - room_id, - parse_string_from_args(query, "event_id", None, required=False), - ) - - -class FederationStateIdsServlet(BaseFederationServerServlet): - PATH = "/state_ids/(?P[^/]*)/?" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - return await self.handler.on_state_ids_request( - origin, - room_id, - parse_string_from_args(query, "event_id", None, required=True), - ) - - -class FederationBackfillServlet(BaseFederationServerServlet): - PATH = "/backfill/(?P[^/]*)/?" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - versions = [x.decode("ascii") for x in query[b"v"]] - limit = parse_integer_from_args(query, "limit", None) - - if not limit: - return 400, {"error": "Did not include limit param"} - - return await self.handler.on_backfill_request(origin, room_id, versions, limit) - - -class FederationQueryServlet(BaseFederationServerServlet): - PATH = "/query/(?P[^/]*)" - - # This is when we receive a server-server Query - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - query_type: str, - ) -> Tuple[int, JsonDict]: - args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()} - args["origin"] = origin - return await self.handler.on_query_request(query_type, args) - - -class FederationMakeJoinServlet(BaseFederationServerServlet): - PATH = "/make_join/(?P[^/]*)/(?P[^/]*)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - """ - Args: - origin: The authenticated server_name of the calling server - - content: (GETs don't have bodies) - - query: Query params from the request. - - **kwargs: the dict mapping keys to path components as specified in - the path match regexp. 
-
-        Returns:
-            Tuple of (response code, response object)
-        """
-        supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
-        if supported_versions is None:
-            supported_versions = ["1"]
-
-        result = await self.handler.on_make_join_request(
-            origin, room_id, user_id, supported_versions=supported_versions
-        )
-        return 200, result
-
-
-class FederationMakeLeaveServlet(BaseFederationServerServlet):
-    PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        result = await self.handler.on_make_leave_request(origin, room_id, user_id)
-        return 200, result
-
-
-class FederationV1SendLeaveServlet(BaseFederationServerServlet):
-    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
-
-    async def on_PUT(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-        event_id: str,
-    ) -> Tuple[int, Tuple[int, JsonDict]]:
-        result = await self.handler.on_send_leave_request(origin, content, room_id)
-        return 200, (200, result)
-
-
-class FederationV2SendLeaveServlet(BaseFederationServerServlet):
-    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
-
-    PREFIX = FEDERATION_V2_PREFIX
-
-    async def on_PUT(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-        event_id: str,
-    ) -> Tuple[int, JsonDict]:
-        result = await self.handler.on_send_leave_request(origin, content, room_id)
-        return 200, result
-
-
-class FederationMakeKnockServlet(BaseFederationServerServlet):
-    PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        # Retrieve the room versions the remote homeserver claims to support
-        supported_versions = parse_strings_from_args(
-            query, "ver", required=True, encoding="utf-8"
-        )
-
-        result = await self.handler.on_make_knock_request(
-            origin, room_id, user_id, supported_versions=supported_versions
-        )
-        return 200, result
-
-
-class FederationV1SendKnockServlet(BaseFederationServerServlet):
-    PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
-
-    async def on_PUT(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-        event_id: str,
-    ) -> Tuple[int, JsonDict]:
-        result = await self.handler.on_send_knock_request(origin, content, room_id)
-        return 200, result
-
-
-class FederationEventAuthServlet(BaseFederationServerServlet):
-    PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-        event_id: str,
-    ) -> Tuple[int, JsonDict]:
-        return await self.handler.on_event_auth(origin, room_id, event_id)
-
-
-class FederationV1SendJoinServlet(BaseFederationServerServlet):
-    PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
-
-    async def on_PUT(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-        event_id: str,
-    ) -> Tuple[int, Tuple[int, JsonDict]]:
-        # TODO(paul): assert that event_id parsed from path actually
-        # match those given in content
-        result = await self.handler.on_send_join_request(origin, content, room_id)
-        return 200, (200, result)
-
-
-class FederationV2SendJoinServlet(BaseFederationServerServlet):
-    PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
-
-    PREFIX = FEDERATION_V2_PREFIX
-
-    async def on_PUT(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-        event_id: str,
-    ) -> Tuple[int, JsonDict]:
-        # TODO(paul): assert that event_id parsed from path actually
-        # match those given in content
-        result = await self.handler.on_send_join_request(origin, content, room_id)
-        return 200, result
-
-
-class FederationV1InviteServlet(BaseFederationServerServlet):
-    PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
-
-    async def on_PUT(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-        event_id: str,
-    ) -> Tuple[int, Tuple[int, JsonDict]]:
-        # We don't get a room version, so we have to assume its EITHER v1 or
-        # v2. This is "fine" as the only difference between V1 and V2 is the
-        # state resolution algorithm, and we don't use that for processing
-        # invites
-        result = await self.handler.on_invite_request(
-            origin, content, room_version_id=RoomVersions.V1.identifier
-        )
-
-        # V1 federation API is defined to return a content of `[200, {...}]`
-        # due to a historical bug.
-        return 200, (200, result)
-
-
-class FederationV2InviteServlet(BaseFederationServerServlet):
-    PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
-
-    PREFIX = FEDERATION_V2_PREFIX
-
-    async def on_PUT(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-        event_id: str,
-    ) -> Tuple[int, JsonDict]:
-        # TODO(paul): assert that room_id/event_id parsed from path actually
-        # match those given in content
-
-        room_version = content["room_version"]
-        event = content["event"]
-        invite_room_state = content["invite_room_state"]
-
-        # Synapse expects invite_room_state to be in unsigned, as it is in v1
-        # API
-
-        event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
-
-        result = await self.handler.on_invite_request(
-            origin, event, room_version_id=room_version
-        )
-        return 200, result
-
-
-class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
-    PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
-
-    async def on_PUT(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        await self.handler.on_exchange_third_party_invite_request(content)
-        return 200, {}
-
-
-class FederationClientKeysQueryServlet(BaseFederationServerServlet):
-    PATH = "/user/keys/query"
-
-    async def on_POST(
-        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
-    ) -> Tuple[int, JsonDict]:
-        return await self.handler.on_query_client_keys(origin, content)
-
-
-class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
-    PATH = "/user/devices/(?P<user_id>[^/]*)"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        return await self.handler.on_query_user_devices(origin, user_id)
-
-
-class FederationClientKeysClaimServlet(BaseFederationServerServlet):
-    PATH = "/user/keys/claim"
-
-    async def on_POST(
-        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
-    ) -> Tuple[int, JsonDict]:
-        response = await self.handler.on_claim_client_keys(origin, content)
-        return 200, response
-
-
-class FederationGetMissingEventsServlet(BaseFederationServerServlet):
-    # TODO(paul): Why does this path alone end with "/?" optional?
-    PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
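
Aside: the named groups in the `PATH` regexps above are not decorative. `BaseFederationServlet.register` compiles `PREFIX + PATH`, and the groups are passed to the `on_GET`/`on_PUT`/`on_POST` handlers as keyword arguments, so each group name has to match a handler parameter. A minimal, self-contained sketch of that dispatch follows; the prefix value is assumed for illustration, and this is not Synapse's actual registration code.

import re

PREFIX = "/_matrix/federation/v1"  # assumed value of FEDERATION_V1_PREFIX
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"

pattern = re.compile("^" + PREFIX + PATH + "$")

def dispatch(uri: str) -> dict:
    match = pattern.match(uri)
    if match is None:
        raise ValueError("no servlet matches %r" % (uri,))
    # The named groups map directly onto the handler's keyword
    # arguments, e.g. on_POST(..., room_id=...).
    return match.groupdict()

print(dispatch(PREFIX + "/get_missing_events/!abc:example.org"))
# -> {'room_id': '!abc:example.org'}
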
- - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - limit = int(content.get("limit", 10)) - earliest_events = content.get("earliest_events", []) - latest_events = content.get("latest_events", []) - - result = await self.handler.on_get_missing_events( - origin, - room_id=room_id, - earliest_events=earliest_events, - latest_events=latest_events, - limit=limit, - ) - - return 200, result - - -class On3pidBindServlet(BaseFederationServerServlet): - PATH = "/3pid/onbind" - - REQUIRE_AUTH = False - - async def on_POST( - self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - if "invites" in content: - last_exception = None - for invite in content["invites"]: - try: - if "signed" not in invite or "token" not in invite["signed"]: - message = ( - "Rejecting received notification of third-" - "party invite without signed: %s" % (invite,) - ) - logger.info(message) - raise SynapseError(400, message) - await self.handler.exchange_third_party_invite( - invite["sender"], - invite["mxid"], - invite["room_id"], - invite["signed"], - ) - except Exception as e: - last_exception = e - if last_exception: - raise last_exception - return 200, {} - - -class OpenIdUserInfo(BaseFederationServerServlet): - """ - Exchange a bearer token for information about a user. - - The response format should be compatible with: - http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse - - GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1 - - HTTP/1.1 200 OK - Content-Type: application/json - - { - "sub": "@userpart:example.org", - } - """ - - PATH = "/openid/userinfo" - - REQUIRE_AUTH = False - - async def on_GET( - self, - origin: Optional[str], - content: Literal[None], - query: Dict[bytes, List[bytes]], - ) -> Tuple[int, JsonDict]: - token = parse_string_from_args(query, "access_token") - if token is None: - return ( - 401, - {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}, - ) - - user_id = await self.handler.on_openid_userinfo(token) - - if user_id is None: - return ( - 401, - { - "errcode": "M_UNKNOWN_TOKEN", - "error": "Access Token unknown or expired", - }, - ) - - return 200, {"sub": user_id} - - -class PublicRoomList(BaseFederationServlet): - """ - Fetch the public room list for this server. - - This API returns information in the same format as /publicRooms on the - client API, but will only ever include local public rooms and hence is - intended for consumption by other homeservers. 
- - GET /publicRooms HTTP/1.1 - - HTTP/1.1 200 OK - Content-Type: application/json - - { - "chunk": [ - { - "aliases": [ - "#test:localhost" - ], - "guest_can_join": false, - "name": "test room", - "num_joined_members": 3, - "room_id": "!whkydVegtvatLfXmPN:localhost", - "world_readable": false - } - ], - "end": "END", - "start": "START" - } - """ - - PATH = "/publicRooms" - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - allow_access: bool, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_room_list_handler() - self.allow_access = allow_access - - async def on_GET( - self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - if not self.allow_access: - raise FederationDeniedError(origin) - - limit = parse_integer_from_args(query, "limit", 0) - since_token = parse_string_from_args(query, "since", None) - include_all_networks = parse_boolean_from_args( - query, "include_all_networks", default=False - ) - third_party_instance_id = parse_string_from_args( - query, "third_party_instance_id", None - ) - - if include_all_networks: - network_tuple = None - elif third_party_instance_id: - network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) - else: - network_tuple = ThirdPartyInstanceID(None, None) - - if limit == 0: - # zero is a special value which corresponds to no limit. - limit = None - - data = await self.handler.get_local_public_room_list( - limit, since_token, network_tuple=network_tuple, from_federation=True - ) - return 200, data - - async def on_POST( - self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - # This implements MSC2197 (Search Filtering over Federation) - if not self.allow_access: - raise FederationDeniedError(origin) - - limit: Optional[int] = int(content.get("limit", 100)) - since_token = content.get("since", None) - search_filter = content.get("filter", None) - - include_all_networks = content.get("include_all_networks", False) - third_party_instance_id = content.get("third_party_instance_id", None) - - if include_all_networks: - network_tuple = None - if third_party_instance_id is not None: - raise SynapseError( - 400, "Can't use include_all_networks with an explicit network" - ) - elif third_party_instance_id is None: - network_tuple = ThirdPartyInstanceID(None, None) - else: - network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) - - if search_filter is None: - logger.warning("Nonefilter") - - if limit == 0: - # zero is a special value which corresponds to no limit. - limit = None - - data = await self.handler.get_local_public_room_list( - limit=limit, - since_token=since_token, - search_filter=search_filter, - network_tuple=network_tuple, - from_federation=True, - ) - - return 200, data - - -class FederationVersionServlet(BaseFederationServlet): - PATH = "/version" - - REQUIRE_AUTH = False - - async def on_GET( - self, - origin: Optional[str], - content: Literal[None], - query: Dict[bytes, List[bytes]], - ) -> Tuple[int, JsonDict]: - return ( - 200, - {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, - ) - - -class BaseGroupsServerServlet(BaseFederationServlet): - """Abstract base class for federation servlet classes which provides a groups server handler. - - See BaseFederationServlet for more information. 
-    """
-
-    def __init__(
-        self,
-        hs: HomeServer,
-        authenticator: Authenticator,
-        ratelimiter: FederationRateLimiter,
-        server_name: str,
-    ):
-        super().__init__(hs, authenticator, ratelimiter, server_name)
-        self.handler = hs.get_groups_server_handler()
-
-
-class FederationGroupsProfileServlet(BaseGroupsServerServlet):
-    """Get/set the basic profile of a group on behalf of a user"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/profile"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.get_group_profile(group_id, requester_user_id)
-
-        return 200, new_content
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.update_group_profile(
-            group_id, requester_user_id, content
-        )
-
-        return 200, new_content
-
-
-class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
-    PATH = "/groups/(?P<group_id>[^/]*)/summary"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.get_group_summary(group_id, requester_user_id)
-
-        return 200, new_content
-
-
-class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
-    """Get the rooms in a group on behalf of a user"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/rooms"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id)
-
-        return 200, new_content
-
-
-class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
-    """Add/remove room from group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.add_room_to_group(
-            group_id, requester_user_id, room_id, content
-        )
-
-        return 200, new_content
-
-    async def on_DELETE(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.remove_room_from_group(
-            group_id, requester_user_id, room_id
-        )
-
-        return 200, new_content
-
-
-class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
-    """Update room config in group"""
-
-    PATH = (
-        "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
-        "/config/(?P<config_key>[^/]*)"
-    )
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        room_id: str,
-        config_key: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        result = await self.handler.update_room_in_group(
-            group_id, requester_user_id, room_id, config_key, content
-        )
-
-        return 200, result
-
-
-class FederationGroupsUsersServlet(BaseGroupsServerServlet):
-    """Get the users in a group on behalf of a user"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/users"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.get_users_in_group(group_id, requester_user_id)
-
-        return 200, new_content
-
-
-class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
-    """Get the users that have been invited to a group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.get_invited_users_in_group(
-            group_id, requester_user_id
-        )
-
-        return 200, new_content
-
-
-class FederationGroupsInviteServlet(BaseGroupsServerServlet):
-    """Ask a group server to invite someone to the group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.invite_to_group(
-            group_id, user_id, requester_user_id, content
-        )
-
-        return 200, new_content
-
-
-class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet):
-    """Accept an invitation from the group server"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        if get_domain_from_id(user_id) != origin:
-            raise SynapseError(403, "user_id doesn't match origin")
-
-        new_content = await self.handler.accept_invite(group_id, user_id, content)
-
-        return 200, new_content
-
-
-class FederationGroupsJoinServlet(BaseGroupsServerServlet):
-    """Attempt to join a group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        if get_domain_from_id(user_id) != origin:
-            raise SynapseError(403, "user_id doesn't match origin")
-
-        new_content = await self.handler.join_group(group_id, user_id, content)
-
-        return 200, new_content
-
-
-class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
-    """Leave or kick a user from the group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.remove_user_from_group(
-            group_id, user_id, requester_user_id, content
-        )
-
-        return 200, new_content
-
-
-class BaseGroupsLocalServlet(BaseFederationServlet):
-    """Abstract base class for federation servlet classes which provides a groups local handler.
-
-    See BaseFederationServlet for more information.
-    """
-
-    def __init__(
-        self,
-        hs: HomeServer,
-        authenticator: Authenticator,
-        ratelimiter: FederationRateLimiter,
-        server_name: str,
-    ):
-        super().__init__(hs, authenticator, ratelimiter, server_name)
-        self.handler = hs.get_groups_local_handler()
-
-
-class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet):
-    """A group server has invited a local user"""
-
-    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        if get_domain_from_id(group_id) != origin:
-            raise SynapseError(403, "group_id doesn't match origin")
-
-        assert isinstance(
-            self.handler, GroupsLocalHandler
-        ), "Workers cannot handle group invites."
-
-        new_content = await self.handler.on_invite(group_id, user_id, content)
-
-        return 200, new_content
-
-
-class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
-    """A group server has removed a local user"""
-
-    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        user_id: str,
-    ) -> Tuple[int, None]:
-        if get_domain_from_id(group_id) != origin:
-            raise SynapseError(403, "user_id doesn't match origin")
-
-        assert isinstance(
-            self.handler, GroupsLocalHandler
-        ), "Workers cannot handle group removals."
- - await self.handler.user_removed_from_group(group_id, user_id, content) - - return 200, None - - -class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): - """A group or user's server renews their attestation""" - - PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_attestation_renewer() - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - # We don't need to check auth here as we check the attestation signatures - - new_content = await self.handler.on_renew_attestation( - group_id, user_id, content - ) - - return 200, new_content - - -class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): - """Add/remove a room from the group summary, with optional category. - - Matches both: - - /groups/:group/summary/rooms/:room_id - - /groups/:group/summary/categories/:category/rooms/:room_id - """ - - PATH = ( - "/groups/(?P[^/]*)/summary" - "(/categories/(?P[^/]+))?" - "/rooms/(?P[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError( - 400, "category_id cannot be empty string", Codes.INVALID_PARAM - ) - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.update_group_summary_room( - group_id, - requester_user_id, - room_id=room_id, - category_id=category_id, - content=content, - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") - - resp = await self.handler.delete_group_summary_room( - group_id, requester_user_id, room_id=room_id, category_id=category_id - ) - - return 200, resp - - -class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): - """Get all categories for a group""" - - PATH = "/groups/(?P[^/]*)/categories/?" 
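The summary-rooms PATH above relies on an optional regex group, so a single servlet serves both URL shapes; when the /categories/ segment is absent, the category_id kwarg arrives as None. A quick stdlib demonstration (federation prefix omitted for brevity):

import re

pattern = re.compile(
    "^/groups/(?P<group_id>[^/]*)/summary"
    "(/categories/(?P<category_id>[^/]+))?"
    "/rooms/(?P<room_id>[^/]*)$"
)

m = pattern.match("/groups/+g:example.org/summary/rooms/!r:example.org")
assert m is not None and m.group("category_id") is None

m = pattern.match("/groups/+g:example.org/summary/categories/fun/rooms/!r:example.org")
assert m is not None and m.group("category_id") == "fun"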
-
-
-class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
-    """Get all categories for a group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        resp = await self.handler.get_group_categories(group_id, requester_user_id)
-
-        return 200, resp
-
-
-class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
-    """Add/remove/get a category in a group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        category_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        resp = await self.handler.get_group_category(
-            group_id, requester_user_id, category_id
-        )
-
-        return 200, resp
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        category_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if category_id == "":
-            raise SynapseError(400, "category_id cannot be empty string")
-
-        if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
-            raise SynapseError(
-                400,
-                "category_id may not be longer than %s characters"
-                % (MAX_GROUP_CATEGORYID_LENGTH,),
-                Codes.INVALID_PARAM,
-            )
-
-        resp = await self.handler.upsert_group_category(
-            group_id, requester_user_id, category_id, content
-        )
-
-        return 200, resp
-
-    async def on_DELETE(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        category_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if category_id == "":
-            raise SynapseError(400, "category_id cannot be empty string")
-
-        resp = await self.handler.delete_group_category(
-            group_id, requester_user_id, category_id
-        )
-
-        return 200, resp
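The category servlet's empty-string and length guards recur for roles below; distilled into one helper, the pattern looks like this (in the real code the limit is imported from synapse.api.constants; the concrete value here is an assumption for the sketch):

MAX_GROUP_CATEGORYID_LENGTH = 255  # assumed value for this sketch

def validate_category_id(category_id: str) -> None:
    if category_id == "":
        raise ValueError("category_id cannot be empty string")
    if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
        raise ValueError(
            "category_id may not be longer than %s characters"
            % (MAX_GROUP_CATEGORYID_LENGTH,)
        )

validate_category_id("fun")  # ok; "" or an over-long ID raises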
-
-
-class FederationGroupsRolesServlet(BaseGroupsServerServlet):
-    """Get roles in a group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        resp = await self.handler.get_group_roles(group_id, requester_user_id)
-
-        return 200, resp
-
-
-class FederationGroupsRoleServlet(BaseGroupsServerServlet):
-    """Add/remove/get a role in a group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        role_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        resp = await self.handler.get_group_role(group_id, requester_user_id, role_id)
-
-        return 200, resp
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        role_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if role_id == "":
-            raise SynapseError(
-                400, "role_id cannot be empty string", Codes.INVALID_PARAM
-            )
-
-        if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
-            raise SynapseError(
-                400,
-                "role_id may not be longer than %s characters"
-                % (MAX_GROUP_ROLEID_LENGTH,),
-                Codes.INVALID_PARAM,
-            )
-
-        resp = await self.handler.update_group_role(
-            group_id, requester_user_id, role_id, content
-        )
-
-        return 200, resp
-
-    async def on_DELETE(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        role_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if role_id == "":
-            raise SynapseError(400, "role_id cannot be empty string")
-
-        resp = await self.handler.delete_group_role(
-            group_id, requester_user_id, role_id
-        )
-
-        return 200, resp
-
-
-class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
-    """Add/remove a user from the group summary, with optional role.
-
-    Matches both:
-        - /groups/:group/summary/users/:user_id
-        - /groups/:group/summary/roles/:role/users/:user_id
-    """
-
-    PATH = (
-        "/groups/(?P<group_id>[^/]*)/summary"
-        "(/roles/(?P<role_id>[^/]+))?"
-        "/users/(?P<user_id>[^/]*)"
-    )
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        role_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if role_id == "":
-            raise SynapseError(400, "role_id cannot be empty string")
-
-        if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
-            raise SynapseError(
-                400,
-                "role_id may not be longer than %s characters"
-                % (MAX_GROUP_ROLEID_LENGTH,),
-                Codes.INVALID_PARAM,
-            )
-
-        resp = await self.handler.update_group_summary_user(
-            group_id,
-            requester_user_id,
-            user_id=user_id,
-            role_id=role_id,
-            content=content,
-        )
-
-        return 200, resp
-
-    async def on_DELETE(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        role_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if role_id == "":
-            raise SynapseError(400, "role_id cannot be empty string")
-
-        resp = await self.handler.delete_group_summary_user(
-            group_id, requester_user_id, user_id=user_id, role_id=role_id
-        )
-
-        return 200, resp
-
-
-class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet):
-    """Get the groups a list of users are publicising"""
-
-    PATH = "/get_groups_publicised"
-
-    async def on_POST(
-        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
-    ) -> Tuple[int, JsonDict]:
-        resp = await self.handler.bulk_get_publicised_groups(
-            content["user_ids"], proxy=False
-        )
-
-        return 200, resp
-
-
-class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
-    """Sets whether a group is joinable without an invite or knock"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
-
-    async def on_PUT(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.set_group_join_policy(
-            group_id, requester_user_id, content
-        )
-
-        return 200, new_content
-
-
-class FederationSpaceSummaryServlet(BaseFederationServlet):
-    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
-    PATH = "/spaces/(?P<room_id>[^/]*)"
-
-    def __init__(
-        self,
-        hs: HomeServer,
-        authenticator: Authenticator,
-        ratelimiter: FederationRateLimiter,
-        server_name: str,
-    ):
-        super().__init__(hs, authenticator, ratelimiter, server_name)
-        self.handler = hs.get_space_summary_handler()
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Mapping[bytes, Sequence[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
-        max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space")
-
-        exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[])
-
-        return 200, await self.handler.federation_space_summary(
-            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
-        )
-
-    # TODO When switching to the stable endpoint, remove the POST handler.
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Mapping[bytes, Sequence[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        suggested_only = content.get("suggested_only", False)
-        if not isinstance(suggested_only, bool):
-            raise SynapseError(
-                400, "'suggested_only' must be a boolean", Codes.BAD_JSON
-            )
-
-        exclude_rooms = content.get("exclude_rooms", [])
-        if not isinstance(exclude_rooms, list) or any(
-            not isinstance(x, str) for x in exclude_rooms
-        ):
-            raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON)
-
-        max_rooms_per_space = content.get("max_rooms_per_space")
-        if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int):
-            raise SynapseError(
-                400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON
-            )
-
-        return 200, await self.handler.federation_space_summary(
-            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
-        )
-
-
-class RoomComplexityServlet(BaseFederationServlet):
-    """
-    Indicates to other servers how complex (and therefore likely
-    resource-intensive) a public room this server knows about is.
-    """
-
-    PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
-    PREFIX = FEDERATION_UNSTABLE_PREFIX
-
-    def __init__(
-        self,
-        hs: HomeServer,
-        authenticator: Authenticator,
-        ratelimiter: FederationRateLimiter,
-        server_name: str,
-    ):
-        super().__init__(hs, authenticator, ratelimiter, server_name)
-        self._store = self.hs.get_datastore()
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        is_public = await self._store.is_room_world_readable_or_publicly_joinable(
-            room_id
-        )
-
-        if not is_public:
-            raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
-
-        complexity = await self._store.get_room_complexity(room_id)
-        return 200, complexity
-
-
-FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
-    FederationSendServlet,
-    FederationEventServlet,
-    FederationStateV1Servlet,
-    FederationStateIdsServlet,
-    FederationBackfillServlet,
-    FederationQueryServlet,
-    FederationMakeJoinServlet,
-    FederationMakeLeaveServlet,
-    FederationEventServlet,
-    FederationV1SendJoinServlet,
-    FederationV2SendJoinServlet,
-    FederationV1SendLeaveServlet,
-    FederationV2SendLeaveServlet,
-    FederationV1InviteServlet,
-    FederationV2InviteServlet,
-    FederationGetMissingEventsServlet,
-    FederationEventAuthServlet,
-    FederationClientKeysQueryServlet,
-    FederationUserDevicesQueryServlet,
-    FederationClientKeysClaimServlet,
-    FederationThirdPartyInviteExchangeServlet,
-    On3pidBindServlet,
-    FederationVersionServlet,
-    RoomComplexityServlet,
-    FederationSpaceSummaryServlet,
-    FederationV1SendKnockServlet,
-    FederationMakeKnockServlet,
-)
-
-OPENID_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (OpenIdUserInfo,)
-
-ROOM_LIST_CLASSES: Tuple[Type[PublicRoomList], ...] = (PublicRoomList,)
-
-GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
-    FederationGroupsProfileServlet,
-    FederationGroupsSummaryServlet,
-    FederationGroupsRoomsServlet,
-    FederationGroupsUsersServlet,
-    FederationGroupsInvitedUsersServlet,
-    FederationGroupsInviteServlet,
-    FederationGroupsAcceptInviteServlet,
-    FederationGroupsJoinServlet,
-    FederationGroupsRemoveUserServlet,
-    FederationGroupsSummaryRoomsServlet,
-    FederationGroupsCategoriesServlet,
-    FederationGroupsCategoryServlet,
-    FederationGroupsRolesServlet,
-    FederationGroupsRoleServlet,
-    FederationGroupsSummaryUsersServlet,
-    FederationGroupsAddRoomsServlet,
-    FederationGroupsAddRoomsConfigServlet,
-    FederationGroupsSettingJoinPolicyServlet,
-)
-
-
-GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
-    FederationGroupsLocalInviteServlet,
-    FederationGroupsRemoveLocalUserServlet,
-    FederationGroupsBulkPublicisedServlet,
-)
-
-
-GROUP_ATTESTATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
-    FederationGroupsRenewAttestaionServlet,
-)
-
-
-DEFAULT_SERVLET_GROUPS = (
-    "federation",
-    "room_list",
-    "group_server",
-    "group_local",
-    "group_attestation",
-    "openid",
-)
-
-
-def register_servlets(
-    hs: HomeServer,
-    resource: HttpServer,
-    authenticator: Authenticator,
-    ratelimiter: FederationRateLimiter,
-    servlet_groups: Optional[Container[str]] = None,
-):
-    """Initialize and register servlet classes.
-
-    Will by default register all servlets. For custom behaviour, pass in
-    a list of servlet_groups to register.
-
-    Args:
-        hs: homeserver
-        resource: resource class to register to
-        authenticator: authenticator to use
-        ratelimiter: ratelimiter to use
-        servlet_groups: List of servlet groups to register.
-            Defaults to ``DEFAULT_SERVLET_GROUPS``.
-    """
-    if not servlet_groups:
-        servlet_groups = DEFAULT_SERVLET_GROUPS
-
-    if "federation" in servlet_groups:
-        for servletclass in FEDERATION_SERVLET_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
-
-    if "openid" in servlet_groups:
-        for servletclass in OPENID_SERVLET_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
-
-    if "room_list" in servlet_groups:
-        for servletclass in ROOM_LIST_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-                allow_access=hs.config.allow_public_rooms_over_federation,
-            ).register(resource)
-
-    if "group_server" in servlet_groups:
-        for servletclass in GROUP_SERVER_SERVLET_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
-
-    if "group_local" in servlet_groups:
-        for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
-
-    if "group_attestation" in servlet_groups:
-        for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
new file mode 100644
index 000000000000..95176ba6f9e8
--- /dev/null
+++ b/synapse/federation/transport/server/__init__.py
@@ -0,0 +1,332 @@
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Dict, Iterable, List, Optional, Tuple, Type
+
+from typing_extensions import Literal
+
+from synapse.api.errors import FederationDeniedError, SynapseError
+from synapse.federation.transport.server._base import (
+    Authenticator,
+    BaseFederationServlet,
+)
+from synapse.federation.transport.server.federation import FEDERATION_SERVLET_CLASSES
+from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES
+from synapse.federation.transport.server.groups_server import (
+    GROUP_SERVER_SERVLET_CLASSES,
+)
+from synapse.http.server import HttpServer, JsonResource
+from synapse.http.servlet import (
+    parse_boolean_from_args,
+    parse_integer_from_args,
+    parse_string_from_args,
+)
+from synapse.server import HomeServer
+from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+logger = logging.getLogger(__name__)
+
+
+class TransportLayerServer(JsonResource):
+    """Handles incoming federation HTTP requests"""
+
+    def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None):
+        """Initialize the TransportLayerServer
+
+        Will by default register all servlets. For custom behaviour, pass in
+        a list of servlet_groups to register.
+
+        Args:
+            hs: homeserver
+            servlet_groups: List of servlet groups to register.
+                Defaults to ``DEFAULT_SERVLET_GROUPS``.
+        """
+        self.hs = hs
+        self.clock = hs.get_clock()
+        self.servlet_groups = servlet_groups
+
+        super().__init__(hs, canonical_json=False)
+
+        self.authenticator = Authenticator(hs)
+        self.ratelimiter = hs.get_federation_ratelimiter()
+
+        self.register_servlets()
+
+    def register_servlets(self) -> None:
+        register_servlets(
+            self.hs,
+            resource=self,
+            ratelimiter=self.ratelimiter,
+            authenticator=self.authenticator,
+            servlet_groups=self.servlet_groups,
+        )
+
+
+class PublicRoomList(BaseFederationServlet):
+    """
+    Fetch the public room list for this server.
+
+    This API returns information in the same format as /publicRooms on the
+    client API, but will only ever include local public rooms and hence is
+    intended for consumption by other homeservers.
+
+    GET /publicRooms HTTP/1.1
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+        "chunk": [
+            {
+                "aliases": [
+                    "#test:localhost"
+                ],
+                "guest_can_join": false,
+                "name": "test room",
+                "num_joined_members": 3,
+                "room_id": "!whkydVegtvatLfXmPN:localhost",
+                "world_readable": false
+            }
+        ],
+        "end": "END",
+        "start": "START"
+    }
+    """
+
+    PATH = "/publicRooms"
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_room_list_handler()
+        self.allow_access = hs.config.allow_public_rooms_over_federation
+
+    async def on_GET(
+        self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        if not self.allow_access:
+            raise FederationDeniedError(origin)
+
+        limit = parse_integer_from_args(query, "limit", 0)
+        since_token = parse_string_from_args(query, "since", None)
+        include_all_networks = parse_boolean_from_args(
+            query, "include_all_networks", default=False
+        )
+        third_party_instance_id = parse_string_from_args(
+            query, "third_party_instance_id", None
+        )
+
+        if include_all_networks:
+            network_tuple = None
+        elif third_party_instance_id:
+            network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+        else:
+            network_tuple = ThirdPartyInstanceID(None, None)
+
+        if limit == 0:
+            # zero is a special value which corresponds to no limit.
+            limit = None
+
+        data = await self.handler.get_local_public_room_list(
+            limit, since_token, network_tuple=network_tuple, from_federation=True
+        )
+        return 200, data
+
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        # This implements MSC2197 (Search Filtering over Federation)
+        if not self.allow_access:
+            raise FederationDeniedError(origin)
+
+        limit: Optional[int] = int(content.get("limit", 100))
+        since_token = content.get("since", None)
+        search_filter = content.get("filter", None)
+
+        include_all_networks = content.get("include_all_networks", False)
+        third_party_instance_id = content.get("third_party_instance_id", None)
+
+        if include_all_networks:
+            network_tuple = None
+            if third_party_instance_id is not None:
+                raise SynapseError(
+                    400, "Can't use include_all_networks with an explicit network"
+                )
+        elif third_party_instance_id is None:
+            network_tuple = ThirdPartyInstanceID(None, None)
+        else:
+            network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+
+        if search_filter is None:
+            logger.warning("Nonefilter")
+
+        if limit == 0:
+            # zero is a special value which corresponds to no limit.
+            limit = None
+
+        data = await self.handler.get_local_public_room_list(
+            limit=limit,
+            since_token=since_token,
+            search_filter=search_filter,
+            network_tuple=network_tuple,
+            from_federation=True,
+        )
+
+        return 200, data
+
+
+class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
+    """A group or user's server renews their attestation"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_groups_attestation_renewer()
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        # We don't need to check auth here as we check the attestation signatures
+
+        new_content = await self.handler.on_renew_attestation(
+            group_id, user_id, content
+        )
+
+        return 200, new_content
+
+
+class OpenIdUserInfo(BaseFederationServlet):
+    """
+    Exchange a bearer token for information about a user.
+
+    The response format should be compatible with:
+        http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
+
+    GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+        "sub": "@userpart:example.org",
+    }
+    """
+
+    PATH = "/openid/userinfo"
+
+    REQUIRE_AUTH = False
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_federation_server()
+
+    async def on_GET(
+        self,
+        origin: Optional[str],
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+    ) -> Tuple[int, JsonDict]:
+        token = parse_string_from_args(query, "access_token")
+        if token is None:
+            return (
+                401,
+                {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"},
+            )
+
+        user_id = await self.handler.on_openid_userinfo(token)
+
+        if user_id is None:
+            return (
+                401,
+                {
+                    "errcode": "M_UNKNOWN_TOKEN",
+                    "error": "Access Token unknown or expired",
+                },
+            )
+
+        return 200, {"sub": user_id}
+
+
+DEFAULT_SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
+    "federation": FEDERATION_SERVLET_CLASSES,
+    "room_list": (PublicRoomList,),
+    "group_server": GROUP_SERVER_SERVLET_CLASSES,
+    "group_local": GROUP_LOCAL_SERVLET_CLASSES,
+    "group_attestation": (FederationGroupsRenewAttestaionServlet,),
+    "openid": (OpenIdUserInfo,),
+}
+
+
+def register_servlets(
+    hs: HomeServer,
+    resource: HttpServer,
+    authenticator: Authenticator,
+    ratelimiter: FederationRateLimiter,
+    servlet_groups: Optional[Iterable[str]] = None,
+):
+    """Initialize and register servlet classes.
+
+    Will by default register all servlets. For custom behaviour, pass in
+    a list of servlet_groups to register.
+
+    Args:
+        hs: homeserver
+        resource: resource class to register to
+        authenticator: authenticator to use
+        ratelimiter: ratelimiter to use
+        servlet_groups: List of servlet groups to register.
+            Defaults to ``DEFAULT_SERVLET_GROUPS``.
+    """
+    if not servlet_groups:
+        servlet_groups = DEFAULT_SERVLET_GROUPS.keys()
+
+    for servlet_group in servlet_groups:
+        # Fail loudly on unknown servlet groups rather than silently skipping them.
+        if servlet_group not in DEFAULT_SERVLET_GROUPS:
+            raise RuntimeError(
+                f"Attempting to register unknown federation servlet: '{servlet_group}'"
+            )
+
+        for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]:
+            servletclass(
+                hs=hs,
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
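Taken together, TransportLayerServer and the new dict-driven register_servlets make the set of registered federation APIs configurable per process. A minimal usage sketch (assuming `hs` is an already-initialised HomeServer; the group names must be keys of DEFAULT_SERVLET_GROUPS):

# Register only the core federation and openid APIs, skipping the groups ones:
resource = TransportLayerServer(hs, servlet_groups=["federation", "openid"])

# Passing a name that is not a key of DEFAULT_SERVLET_GROUPS now fails fast:
# TransportLayerServer(hs, servlet_groups=["typo"]) raises
#     RuntimeError: Attempting to register unknown federation servlet: 'typo'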
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
new file mode 100644
index 000000000000..624c859f1e70
--- /dev/null
+++ b/synapse/federation/transport/server/_base.py
@@ -0,0 +1,328 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import logging
+import re
+
+from synapse.api.errors import Codes, FederationDeniedError, SynapseError
+from synapse.api.urls import FEDERATION_V1_PREFIX
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.logging import opentracing
+from synapse.logging.context import run_in_background
+from synapse.logging.opentracing import (
+    SynapseTags,
+    start_active_span,
+    start_active_span_from_request,
+    tags,
+    whitelisted_homeserver,
+)
+from synapse.server import HomeServer
+from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.stringutils import parse_and_validate_server_name
+
+logger = logging.getLogger(__name__)
+
+
+class AuthenticationError(SynapseError):
+    """There was a problem authenticating the request"""
+
+
+class NoAuthenticationError(AuthenticationError):
+    """The request had no authentication information"""
+
+
+class Authenticator:
+    def __init__(self, hs: HomeServer):
+        self._clock = hs.get_clock()
+        self.keyring = hs.get_keyring()
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+        self.notifier = hs.get_notifier()
+
+        self.replication_client = None
+        if hs.config.worker.worker_app:
+            self.replication_client = hs.get_tcp_replication()
+
+    # A method just so we can pass 'self' as the authenticator to the Servlets
+    async def authenticate_request(self, request, content):
+        now = self._clock.time_msec()
+        json_request = {
+            "method": request.method.decode("ascii"),
+            "uri": request.uri.decode("ascii"),
+            "destination": self.server_name,
+            "signatures": {},
+        }
+
+        if content is not None:
+            json_request["content"] = content
+
+        origin = None
+
+        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+        if not auth_headers:
+            raise NoAuthenticationError(
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED
+            )
+
+        for auth in auth_headers:
+            if auth.startswith(b"X-Matrix"):
+                (origin, key, sig) = _parse_auth_header(auth)
+                json_request["origin"] = origin
+                json_request["signatures"].setdefault(origin, {})[key] = sig
+
+        if (
+            self.federation_domain_whitelist is not None
+            and origin not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(origin)
+
+        if origin is None or not json_request["signatures"]:
+            raise NoAuthenticationError(
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED
+            )
+
+        await self.keyring.verify_json_for_server(
+            origin,
+            json_request,
+            now,
+        )
+
+        logger.debug("Request from %s", origin)
+        request.requester = origin
+
+        # If we get a valid signed request from the other side, it's probably
+        # alive
+        retry_timings = await self.store.get_destination_retry_timings(origin)
+        if retry_timings and retry_timings.retry_last_ts:
+            run_in_background(self._reset_retry_timings, origin)
+
+        return origin
+
+    async def _reset_retry_timings(self, origin):
+        try:
+            logger.info("Marking origin %r as up", origin)
+            await self.store.set_destination_retry_timings(origin, None, 0, 0)
+
+            # Inform the relevant places that the remote server is back up.
+            self.notifier.notify_remote_server_up(origin)
+            if self.replication_client:
+                # If we're on a worker we try and inform master about this. The
+                # replication client doesn't hook into the notifier to avoid
+                # infinite loops where we send a `REMOTE_SERVER_UP` command to
+                # master, which then echoes it back to us which in turn pokes
+                # the notifier.
+                self.replication_client.send_remote_server_up(origin)
+
+        except Exception:
+            logger.exception("Error resetting retry timings on %s", origin)
+
+
+def _parse_auth_header(header_bytes):
+    """Parse an X-Matrix auth header
+
+    Args:
+        header_bytes (bytes): header value
+
+    Returns:
+        Tuple[str, str, str]: origin, key id, signature.
+
+    Raises:
+        AuthenticationError if the header could not be parsed
+    """
+    try:
+        header_str = header_bytes.decode("utf-8")
+        params = header_str.split(" ")[1].split(",")
+        param_dict = dict(kv.split("=") for kv in params)
+
+        def strip_quotes(value):
+            if value.startswith('"'):
+                return value[1:-1]
+            else:
+                return value
+
+        origin = strip_quotes(param_dict["origin"])
+
+        # ensure that the origin is a valid server name
+        parse_and_validate_server_name(origin)
+
+        key = strip_quotes(param_dict["key"])
+        sig = strip_quotes(param_dict["sig"])
+        return origin, key, sig
+    except Exception as e:
+        logger.warning(
+            "Error parsing auth header '%s': %s",
+            header_bytes.decode("ascii", "replace"),
+            e,
+        )
+        raise AuthenticationError(
+            400, "Malformed Authorization header", Codes.UNAUTHORIZED
+        )
+
+
+class BaseFederationServlet:
+    """Abstract base class for federation servlet classes.
+
+    The servlet object should have a PATH attribute which takes the form of a regexp to
+    match against the request path (excluding the /federation/v1 prefix).
+
+    The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
+    the appropriate HTTP method. These methods must be *asynchronous* and have the
+    signature:
+
+        on_<METHOD>(self, origin, content, query, **kwargs)
+
+        With arguments:
+
+            origin (unicode|None): The authenticated server_name of the calling server,
+                unless REQUIRE_AUTH is set to False and authentication failed.
+
+            content (unicode|None): decoded json body of the request. None if the
+                request was a GET.
+
+            query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
+                (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
+                yet.
+
+            **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+                components as specified in the path match regexp.
+
+        Returns:
+            Optional[Tuple[int, object]]: either (response code, response object) to
+                return a JSON response, or None if the request has already been handled.
+
+        Raises:
+            SynapseError: to return an error code
+
+            Exception: other exceptions will be caught, logged, and a 500 will be
+                returned.
+    """
+
+    PATH = ""  # Overridden in subclasses, the regex to match against the path.
+
+    REQUIRE_AUTH = True
+
+    PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version
+
+    RATELIMIT = True  # Whether to rate limit requests or not
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        self.hs = hs
+        self.authenticator = authenticator
+        self.ratelimiter = ratelimiter
+        self.server_name = server_name
+
+    def _wrap(self, func):
+        authenticator = self.authenticator
+        ratelimiter = self.ratelimiter
+
+        @functools.wraps(func)
+        async def new_func(request, *args, **kwargs):
+            """A callback which can be passed to HttpServer.RegisterPaths
+
+            Args:
+                request (twisted.web.http.Request):
+                *args: unused?
+                **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+                    components as specified in the path match regexp.
+
+            Returns:
+                Tuple[int, object]|None: (response code, response object) as returned by
+                    the callback method. None if the request has already been handled.
+            """
+            content = None
+            if request.method in [b"PUT", b"POST"]:
+                # TODO: Handle other method types? other content types?
+                content = parse_json_object_from_request(request)
+
+            try:
+                origin = await authenticator.authenticate_request(request, content)
+            except NoAuthenticationError:
+                origin = None
+                if self.REQUIRE_AUTH:
+                    logger.warning(
+                        "authenticate_request failed: missing authentication"
+                    )
+                    raise
+            except Exception as e:
+                logger.warning("authenticate_request failed: %s", e)
+                raise
+
+            request_tags = {
+                SynapseTags.REQUEST_ID: request.get_request_id(),
+                tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+                tags.HTTP_METHOD: request.get_method(),
+                tags.HTTP_URL: request.get_redacted_uri(),
+                tags.PEER_HOST_IPV6: request.getClientIP(),
+                "authenticated_entity": origin,
+                "servlet_name": request.request_metrics.name,
+            }
+
+            # Only accept the span context if the origin is authenticated
+            # and whitelisted
+            if origin and whitelisted_homeserver(origin):
+                scope = start_active_span_from_request(
+                    request, "incoming-federation-request", tags=request_tags
+                )
+            else:
+                scope = start_active_span(
+                    "incoming-federation-request", tags=request_tags
+                )
+
+            with scope:
+                opentracing.inject_response_headers(request.responseHeaders)
+
+                if origin and self.RATELIMIT:
+                    with ratelimiter.ratelimit(origin) as d:
+                        await d
+                        if request._disconnected:
+                            logger.warning(
+                                "client disconnected before we started processing "
+                                "request"
+                            )
+                            return -1, None
+                        response = await func(
+                            origin, content, request.args, *args, **kwargs
+                        )
+                else:
+                    response = await func(
+                        origin, content, request.args, *args, **kwargs
+                    )
+
+            return response
+
+        return new_func
+
+    def register(self, server):
+        pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
+
+        for method in ("GET", "PUT", "POST"):
+            code = getattr(self, "on_%s" % (method), None)
+            if code is None:
+                continue
+
+            server.register_paths(
+                method,
+                (pattern,),
+                self._wrap(code),
+                self.__class__.__name__,
+            )
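As a standalone sketch of the happy path through _parse_auth_header above, here is how a well-formed X-Matrix header decomposes (the signature value is a made-up placeholder):

header = b'X-Matrix origin=example.org,key="ed25519:key1",sig="ABCDEF..."'

params = header.decode("utf-8").split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)

def strip_quotes(value: str) -> str:
    return value[1:-1] if value.startswith('"') else value

origin = strip_quotes(param_dict["origin"])  # "example.org"
key = strip_quotes(param_dict["key"])        # "ed25519:key1"
sig = strip_quotes(param_dict["sig"])        # "ABCDEF..."

The resulting origin, key id, and signature are then fed into verify_json_for_server, so a header that fails this parse never reaches the keyring.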
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
new file mode 100644
index 000000000000..2fdf6cc99e49
--- /dev/null
+++ b/synapse/federation/transport/server/federation.py
@@ -0,0 +1,706 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
+
+from typing_extensions import Literal
+
+import synapse
+from synapse.api.errors import Codes, SynapseError
+from synapse.api.room_versions import RoomVersions
+from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX
+from synapse.federation.transport.server._base import (
+    Authenticator,
+    BaseFederationServlet,
+)
+from synapse.http.servlet import (
+    parse_boolean_from_args,
+    parse_integer_from_args,
+    parse_string_from_args,
+    parse_strings_from_args,
+)
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.versionstring import get_version_string
+
+logger = logging.getLogger(__name__)
+
+
+class BaseFederationServerServlet(BaseFederationServlet):
+    """Abstract base class for federation servlet classes which provides a federation server handler.
+
+    See BaseFederationServlet for more information.
+    """
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_federation_server()
+
+
+class FederationSendServlet(BaseFederationServerServlet):
+    PATH = "/send/(?P<transaction_id>[^/]*)/?"
+
+    # We ratelimit manually in the handler as we queue up the requests and we
+    # don't want to fill up the ratelimiter with blocked requests.
+    RATELIMIT = False
+
+    # This is when someone is trying to send us a bunch of data.
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        transaction_id: str,
+    ) -> Tuple[int, JsonDict]:
+        """Called on PUT /send/<transaction_id>/
+
+        Args:
+            transaction_id: The transaction_id associated with this request. This
+                is *not* None.
+
+        Returns:
+            Tuple of `(code, response)`, where
+            `response` is a python dict to be converted into JSON that is
+            used as the response body.
+        """
+        # Parse the request
+        try:
+            transaction_data = content
+
+            logger.debug("Decoded %s: %s", transaction_id, str(transaction_data))
+
+            logger.info(
+                "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
+                transaction_id,
+                origin,
+                len(transaction_data.get("pdus", [])),
+                len(transaction_data.get("edus", [])),
+            )
+
+        except Exception as e:
+            logger.exception(e)
+            return 400, {"error": "Invalid transaction"}
+
+        code, response = await self.handler.on_incoming_transaction(
+            origin, transaction_id, self.server_name, transaction_data
+        )
+
+        return code, response
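For orientation, a minimal transaction body of the shape the handler logs above looks like the following (the PDU and EDU entries are placeholders, not complete federation events):

transaction_data = {
    "origin": "example.org",
    "origin_server_ts": 1620000000000,
    "pdus": [{"type": "m.room.message", "room_id": "!r:example.org"}],
    "edus": [{"edu_type": "m.typing", "content": {}}],
}
# A PUT to .../send/915c2f06 with this body would log:
#     Received txn 915c2f06 from example.org. (PDUs: 1, EDUs: 1)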
+
+
+class FederationEventServlet(BaseFederationServerServlet):
+    PATH = "/event/(?P<event_id>[^/]*)/?"
+
+    # This is when someone asks for a data item for a given server data_id pair.
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        event_id: str,
+    ) -> Tuple[int, Union[JsonDict, str]]:
+        return await self.handler.on_pdu_request(origin, event_id)
+
+
+class FederationStateV1Servlet(BaseFederationServerServlet):
+    PATH = "/state/(?P<room_id>[^/]*)/?"
+
+    # This is when someone asks for all data for a given room.
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        return await self.handler.on_room_state_request(
+            origin,
+            room_id,
+            parse_string_from_args(query, "event_id", None, required=False),
+        )
+
+
+class FederationStateIdsServlet(BaseFederationServerServlet):
+    PATH = "/state_ids/(?P<room_id>[^/]*)/?"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        return await self.handler.on_state_ids_request(
+            origin,
+            room_id,
+            parse_string_from_args(query, "event_id", None, required=True),
+        )
+
+
+class FederationBackfillServlet(BaseFederationServerServlet):
+    PATH = "/backfill/(?P<room_id>[^/]*)/?"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        versions = [x.decode("ascii") for x in query[b"v"]]
+        limit = parse_integer_from_args(query, "limit", None)
+
+        if not limit:
+            return 400, {"error": "Did not include limit param"}
+
+        return await self.handler.on_backfill_request(origin, room_id, versions, limit)
+
+
+class FederationQueryServlet(BaseFederationServerServlet):
+    PATH = "/query/(?P<query_type>[^/]*)"
+
+    # This is when we receive a server-server Query
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        query_type: str,
+    ) -> Tuple[int, JsonDict]:
+        args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}
+        args["origin"] = origin
+        return await self.handler.on_query_request(query_type, args)
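The query servlet above flattens Twisted's Dict[bytes, List[bytes]] query args into plain strings, keeping only the first value for each key, before appending the authenticated caller. A tiny self-contained demonstration:

query = {b"field": [b"displayname"], b"user_id": [b"@alice:example.org"]}
args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}
args["origin"] = "example.org"  # the authenticated caller is appended
assert args == {
    "field": "displayname",
    "user_id": "@alice:example.org",
    "origin": "example.org",
}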
+
+
+class FederationMakeJoinServlet(BaseFederationServerServlet):
+    PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        """
+        Args:
+            origin: The authenticated server_name of the calling server
+
+            content: (GETs don't have bodies)
+
+            query: Query params from the request.
+
+            **kwargs: the dict mapping keys to path components as specified in
+                the path match regexp.
+
+        Returns:
+            Tuple of (response code, response object)
+        """
+        supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
+        if supported_versions is None:
+            supported_versions = ["1"]
+
+        result = await self.handler.on_make_join_request(
+            origin, room_id, user_id, supported_versions=supported_versions
+        )
+        return 200, result
+
+
+class FederationMakeLeaveServlet(BaseFederationServerServlet):
+    PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        result = await self.handler.on_make_leave_request(origin, room_id, user_id)
+        return 200, result
+
+
+class FederationV1SendLeaveServlet(BaseFederationServerServlet):
+    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, Tuple[int, JsonDict]]:
+        result = await self.handler.on_send_leave_request(origin, content, room_id)
+        return 200, (200, result)
+
+
+class FederationV2SendLeaveServlet(BaseFederationServerServlet):
+    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    PREFIX = FEDERATION_V2_PREFIX
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        result = await self.handler.on_send_leave_request(origin, content, room_id)
+        return 200, result
+
+
+class FederationMakeKnockServlet(BaseFederationServerServlet):
+    PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        # Retrieve the room versions the remote homeserver claims to support
+        supported_versions = parse_strings_from_args(
+            query, "ver", required=True, encoding="utf-8"
+        )
+
+        result = await self.handler.on_make_knock_request(
+            origin, room_id, user_id, supported_versions=supported_versions
+        )
+        return 200, result
+
+
+class FederationV1SendKnockServlet(BaseFederationServerServlet):
+    PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        result = await self.handler.on_send_knock_request(origin, content, room_id)
+        return 200, result
+
+
+class FederationEventAuthServlet(BaseFederationServerServlet):
+    PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        return await self.handler.on_event_auth(origin, room_id, event_id)
+
+
+class FederationV1SendJoinServlet(BaseFederationServerServlet):
+    PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, Tuple[int, JsonDict]]:
+        # TODO(paul): assert that event_id parsed from path actually
+        #   match those given in content
+        result = await self.handler.on_send_join_request(origin, content, room_id)
+        return 200, (200, result)
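Note the v1/v2 split above: the v1 servlets return `200, (200, result)`, so the JSON body on the wire is the historical `[200, {...}]` array, while v2 drops the wrapper. A small illustration of the difference in the serialised bodies:

import json

result = {"state": [], "auth_chain": []}  # stand-in for the real response

v1_body = (200, result)  # what the v1 servlet returns as its JSON body
v2_body = result         # v2 drops the historical wrapper

print(json.dumps(v1_body))  # [200, {"state": [], "auth_chain": []}]
print(json.dumps(v2_body))  # {"state": [], "auth_chain": []}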
+
+
+class FederationV2SendJoinServlet(BaseFederationServerServlet):
+    PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    PREFIX = FEDERATION_V2_PREFIX
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        # TODO(paul): assert that event_id parsed from path actually
+        #   match those given in content
+        result = await self.handler.on_send_join_request(origin, content, room_id)
+        return 200, result
+
+
+class FederationV1InviteServlet(BaseFederationServerServlet):
+    PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, Tuple[int, JsonDict]]:
+        # We don't get a room version, so we have to assume its EITHER v1 or
+        # v2. This is "fine" as the only difference between V1 and V2 is the
+        # state resolution algorithm, and we don't use that for processing
+        # invites
+        result = await self.handler.on_invite_request(
+            origin, content, room_version_id=RoomVersions.V1.identifier
+        )
+
+        # V1 federation API is defined to return a content of `[200, {...}]`
+        # due to a historical bug.
+        return 200, (200, result)
+
+
+class FederationV2InviteServlet(BaseFederationServerServlet):
+    PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    PREFIX = FEDERATION_V2_PREFIX
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        # TODO(paul): assert that room_id/event_id parsed from path actually
+        #   match those given in content
+
+        room_version = content["room_version"]
+        event = content["event"]
+        invite_room_state = content["invite_room_state"]
+
+        # Synapse expects invite_room_state to be in unsigned, as it is in v1
+        # API
+
+        event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
+
+        result = await self.handler.on_invite_request(
+            origin, event, room_version_id=room_version
+        )
+        return 200, result
+
+
+class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
+    PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        await self.handler.on_exchange_third_party_invite_request(content)
+        return 200, {}
+
+
+class FederationClientKeysQueryServlet(BaseFederationServerServlet):
+    PATH = "/user/keys/query"
+
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        return await self.handler.on_query_client_keys(origin, content)
+
+
+class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
+    PATH = "/user/devices/(?P<user_id>[^/]*)"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        return await self.handler.on_query_user_devices(origin, user_id)
+
+
+class FederationClientKeysClaimServlet(BaseFederationServerServlet):
+    PATH = "/user/keys/claim"
+
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        response = await self.handler.on_claim_client_keys(origin, content)
+        return 200, response
+
+
+class FederationGetMissingEventsServlet(BaseFederationServerServlet):
+    # TODO(paul): Why does this path alone end with "/?" optional?
+    PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        limit = int(content.get("limit", 10))
+        earliest_events = content.get("earliest_events", [])
+        latest_events = content.get("latest_events", [])
+
+        result = await self.handler.on_get_missing_events(
+            origin,
+            room_id=room_id,
+            earliest_events=earliest_events,
+            latest_events=latest_events,
+            limit=limit,
+        )
+
+        return 200, result
+
+
+class On3pidBindServlet(BaseFederationServerServlet):
+    PATH = "/3pid/onbind"
+
+    REQUIRE_AUTH = False
+
+    async def on_POST(
+        self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        if "invites" in content:
+            last_exception = None
+            for invite in content["invites"]:
+                try:
+                    if "signed" not in invite or "token" not in invite["signed"]:
+                        message = (
+                            "Rejecting received notification of third-"
+                            "party invite without signed: %s" % (invite,)
+                        )
+                        logger.info(message)
+                        raise SynapseError(400, message)
+                    await self.handler.exchange_third_party_invite(
+                        invite["sender"],
+                        invite["mxid"],
+                        invite["room_id"],
+                        invite["signed"],
+                    )
+                except Exception as e:
+                    last_exception = e
+            if last_exception:
+                raise last_exception
+        return 200, {}
+
+
+class FederationVersionServlet(BaseFederationServlet):
+    PATH = "/version"
+
+    REQUIRE_AUTH = False
+
+    async def on_GET(
+        self,
+        origin: Optional[str],
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+    ) -> Tuple[int, JsonDict]:
+        return (
+            200,
+            {"server": {"name": "Synapse", "version": get_version_string(synapse)}},
+        )
+
+
+class FederationSpaceSummaryServlet(BaseFederationServlet):
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
+    PATH = "/spaces/(?P<room_id>[^/]*)"
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_room_summary_handler()
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Mapping[bytes, Sequence[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
+
+        max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space")
+        if max_rooms_per_space is not None and max_rooms_per_space < 0:
+            raise SynapseError(
+                400,
+                "Value for 'max_rooms_per_space' must be a non-negative integer",
+                Codes.BAD_JSON,
+            )
+
+        exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[])
+
+        return 200, await self.handler.federation_space_summary(
+            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
+        )
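For reference, the MSC2946 GET parameters arrive as Twisted delivers them, bytes keys with a list of values per key, which is how exclude_rooms can be passed several times. A small sketch of the equivalent parsing:

query = {
    b"suggested_only": [b"true"],
    b"max_rooms_per_space": [b"10"],
    b"exclude_rooms": [b"!a:example.org", b"!b:example.org"],
}
# Equivalent of parse_strings_from_args(query, "exclude_rooms", default=[]):
exclude_rooms = [v.decode("ascii") for v in query.get(b"exclude_rooms", [])]
assert exclude_rooms == ["!a:example.org", "!b:example.org"]
# A negative max_rooms_per_space is rejected with M_BAD_JSON, as above.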
+    # TODO When switching to the stable endpoint, remove the POST handler.
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Mapping[bytes, Sequence[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        suggested_only = content.get("suggested_only", False)
+        if not isinstance(suggested_only, bool):
+            raise SynapseError(
+                400, "'suggested_only' must be a boolean", Codes.BAD_JSON
+            )
+
+        exclude_rooms = content.get("exclude_rooms", [])
+        if not isinstance(exclude_rooms, list) or any(
+            not isinstance(x, str) for x in exclude_rooms
+        ):
+            raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON)
+
+        max_rooms_per_space = content.get("max_rooms_per_space")
+        if max_rooms_per_space is not None:
+            if not isinstance(max_rooms_per_space, int):
+                raise SynapseError(
+                    400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON
+                )
+            if max_rooms_per_space < 0:
+                raise SynapseError(
+                    400,
+                    "Value for 'max_rooms_per_space' must be a non-negative integer",
+                    Codes.BAD_JSON,
+                )
+
+        return 200, await self.handler.federation_space_summary(
+            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
+        )
+
+
+class FederationRoomHierarchyServlet(BaseFederationServlet):
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
+    PATH = "/hierarchy/(?P<room_id>[^/]*)"
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_room_summary_handler()
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Mapping[bytes, Sequence[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
+        return 200, await self.handler.get_federation_hierarchy(
+            origin, room_id, suggested_only
+        )
+
+
+class RoomComplexityServlet(BaseFederationServlet):
+    """
+    Indicates to other servers how complex (and therefore likely
+    resource-intensive) a public room this server knows about is.
+    """
+
+    PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
+    PREFIX = FEDERATION_UNSTABLE_PREFIX
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self._store = self.hs.get_datastore()
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        is_public = await self._store.is_room_world_readable_or_publicly_joinable(
+            room_id
+        )
+
+        if not is_public:
+            raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
+
+        complexity = await self._store.get_room_complexity(room_id)
+        return 200, complexity
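The complexity endpoint simply serialises whatever dict the store returns; under MSC1121 that is a JSON object keyed by metric version (the concrete value below is illustrative):

complexity = {"v1": 9.0}  # assumed shape of get_room_complexity's result
# GET /_matrix/federation/unstable/rooms/!r:example.org/complexity
#     -> 200, {"v1": 9.0}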
= ( + FederationSendServlet, + FederationEventServlet, + FederationStateV1Servlet, + FederationStateIdsServlet, + FederationBackfillServlet, + FederationQueryServlet, + FederationMakeJoinServlet, + FederationMakeLeaveServlet, + FederationV1SendJoinServlet, + FederationV2SendJoinServlet, + FederationV1SendLeaveServlet, + FederationV2SendLeaveServlet, + FederationV1InviteServlet, + FederationV2InviteServlet, + FederationGetMissingEventsServlet, + FederationEventAuthServlet, + FederationClientKeysQueryServlet, + FederationUserDevicesQueryServlet, + FederationClientKeysClaimServlet, + FederationThirdPartyInviteExchangeServlet, + On3pidBindServlet, + FederationVersionServlet, + RoomComplexityServlet, + FederationSpaceSummaryServlet, + FederationRoomHierarchyServlet, + FederationV1SendKnockServlet, + FederationMakeKnockServlet, +) diff --git a/synapse/federation/transport/server/groups_local.py b/synapse/federation/transport/server/groups_local.py new file mode 100644 index 000000000000..a12cd18d5806 --- /dev/null +++ b/synapse/federation/transport/server/groups_local.py @@ -0,0 +1,113 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Tuple, Type + +from synapse.api.errors import SynapseError +from synapse.federation.transport.server._base import ( + Authenticator, + BaseFederationServlet, +) +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.server import HomeServer +from synapse.types import JsonDict, get_domain_from_id +from synapse.util.ratelimitutils import FederationRateLimiter + + +class BaseGroupsLocalServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups local handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_local_handler() + + +class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): + """A group server has invited a local user""" + + PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + if get_domain_from_id(group_id) != origin: + raise SynapseError(403, "group_id doesn't match origin") + + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group invites."
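The origin checks in these group servlets all follow one pattern: the authenticated origin must be the homeserver that owns the identifier in the path. A rough standalone sketch of that assumption (IDs invented):

    from synapse.types import get_domain_from_id

    # get_domain_from_id returns everything after the first colon, so only
    # the server that owns a group may send notifications about it.
    def origin_owns_group(origin: str, group_id: str) -> bool:
        return get_domain_from_id(group_id) == origin

    assert origin_owns_group("example.org", "+staff:example.org")
    assert not origin_owns_group("evil.example", "+staff:example.org")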
+ + new_content = await self.handler.on_invite(group_id, user_id, content) + + return 200, new_content + + +class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): + """A group server has removed a local user""" + + PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, None]: + if get_domain_from_id(group_id) != origin: + raise SynapseError(403, "group_id doesn't match origin") + + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group removals." + + await self.handler.user_removed_from_group(group_id, user_id, content) + + return 200, None + + +class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): + """Get the groups a list of users have publicised""" + + PATH = "/get_groups_publicised" + + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + resp = await self.handler.bulk_get_publicised_groups( + content["user_ids"], proxy=False + ) + + return 200, resp + + +GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( + FederationGroupsLocalInviteServlet, + FederationGroupsRemoveLocalUserServlet, + FederationGroupsBulkPublicisedServlet, +) diff --git a/synapse/federation/transport/server/groups_server.py b/synapse/federation/transport/server/groups_server.py new file mode 100644 index 000000000000..b30e92a5eb74 --- /dev/null +++ b/synapse/federation/transport/server/groups_server.py @@ -0,0 +1,753 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Tuple, Type + +from typing_extensions import Literal + +from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH +from synapse.api.errors import Codes, SynapseError +from synapse.federation.transport.server._base import ( + Authenticator, + BaseFederationServlet, +) +from synapse.http.servlet import parse_string_from_args +from synapse.server import HomeServer +from synapse.types import JsonDict, get_domain_from_id +from synapse.util.ratelimitutils import FederationRateLimiter + + +class BaseGroupsServerServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups server handler. + + See BaseFederationServlet for more information.
+ """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_server_handler() + + +class FederationGroupsProfileServlet(BaseGroupsServerServlet): + """Get/set the basic profile of a group on behalf of a user""" + + PATH = "/groups/(?P[^/]*)/profile" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_group_profile(group_id, requester_user_id) + + return 200, new_content + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.update_group_profile( + group_id, requester_user_id, content + ) + + return 200, new_content + + +class FederationGroupsSummaryServlet(BaseGroupsServerServlet): + PATH = "/groups/(?P[^/]*)/summary" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_group_summary(group_id, requester_user_id) + + return 200, new_content + + +class FederationGroupsRoomsServlet(BaseGroupsServerServlet): + """Get the rooms in a group on behalf of a user""" + + PATH = "/groups/(?P[^/]*)/rooms" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id) + + return 200, new_content + + +class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): + """Add/remove room from group""" + + PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.add_room_to_group( + group_id, requester_user_id, room_id, content + ) + + return 200, new_content + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + 
) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.remove_room_from_group( + group_id, requester_user_id, room_id + ) + + return 200, new_content + + +class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): + """Update room config in group""" + + PATH = ( + "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" + "/config/(?P<config_key>[^/]*)" + ) + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + config_key: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + result = await self.handler.update_room_in_group( + group_id, requester_user_id, room_id, config_key, content + ) + + return 200, result + + +class FederationGroupsUsersServlet(BaseGroupsServerServlet): + """Get the users in a group on behalf of a user""" + + PATH = "/groups/(?P<group_id>[^/]*)/users" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_users_in_group(group_id, requester_user_id) + + return 200, new_content + + +class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): + """Get the users that have been invited to a group""" + + PATH = "/groups/(?P<group_id>[^/]*)/invited_users" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_invited_users_in_group( + group_id, requester_user_id + ) + + return 200, new_content + + +class FederationGroupsInviteServlet(BaseGroupsServerServlet): + """Ask a group server to invite someone to the group""" + + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.invite_to_group( + group_id, user_id, requester_user_id, content + ) + + return 200, new_content + + +class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): + """Accept an invitation from the group server""" + + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = await self.handler.accept_invite(group_id, user_id, content) + + return 200, new_content + + +class
FederationGroupsJoinServlet(BaseGroupsServerServlet): + """Attempt to join a group""" + + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = await self.handler.join_group(group_id, user_id, content) + + return 200, new_content + + +class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): + """Leave or kick a user from the group""" + + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.remove_user_from_group( + group_id, user_id, requester_user_id, content + ) + + return 200, new_content + + +class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): + """Add/remove a room from the group summary, with optional category. + + Matches both: + - /groups/:group/summary/rooms/:room_id + - /groups/:group/summary/categories/:category/rooms/:room_id + """ + + PATH = ( + "/groups/(?P<group_id>[^/]*)/summary" + "(/categories/(?P<category_id>[^/]+))?" + "/rooms/(?P<room_id>[^/]*)" + ) + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError( + 400, "category_id cannot be empty string", Codes.INVALID_PARAM + ) + + if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) + + resp = await self.handler.update_group_summary_room( + group_id, + requester_user_id, + room_id=room_id, + category_id=category_id, + content=content, + ) + + return 200, resp + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = await self.handler.delete_group_summary_room( + group_id, requester_user_id, room_id=room_id, category_id=category_id + ) + + return 200, resp + + +class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): + """Get all categories for a group""" + + PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
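The optional category segment in FederationGroupsSummaryRoomsServlet's PATH above is worth a standalone check; this sketch (group and room IDs invented) confirms it matches both documented forms:

    import re

    PATH = (
        "/groups/(?P<group_id>[^/]*)/summary"
        "(/categories/(?P<category_id>[^/]+))?"
        "/rooms/(?P<room_id>[^/]*)"
    )
    plain = re.match(PATH, "/groups/+g:example.org/summary/rooms/!r:example.org")
    with_cat = re.match(
        PATH, "/groups/+g:example.org/summary/categories/cat1/rooms/!r:example.org"
    )
    assert plain and plain.group("category_id") is None
    assert with_cat and with_cat.group("category_id") == "cat1"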
+ + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_categories(group_id, requester_user_id) + + return 200, resp + + +class FederationGroupsCategoryServlet(BaseGroupsServerServlet): + """Add/remove/get a category in a group""" + + PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_category( + group_id, requester_user_id, category_id + ) + + return 200, resp + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) + + resp = await self.handler.upsert_group_category( + group_id, requester_user_id, category_id, content + ) + + return 200, resp + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = await self.handler.delete_group_category( + group_id, requester_user_id, category_id + ) + + return 200, resp + + +class FederationGroupsRolesServlet(BaseGroupsServerServlet): + """Get roles in a group""" + + PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
+ + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_roles(group_id, requester_user_id) + + return 200, resp + + +class FederationGroupsRoleServlet(BaseGroupsServerServlet): + """Add/remove/get a role in a group""" + + PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_role(group_id, requester_user_id, role_id) + + return 200, resp + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError( + 400, "role_id cannot be empty string", Codes.INVALID_PARAM + ) + + if len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) + + resp = await self.handler.update_group_role( + group_id, requester_user_id, role_id, content + ) + + return 200, resp + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = await self.handler.delete_group_role( + group_id, requester_user_id, role_id + ) + + return 200, resp + + +class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): + """Add/remove a user from the group summary, with optional role. + + Matches both: + - /groups/:group/summary/users/:user_id + - /groups/:group/summary/roles/:role/users/:user_id + """ + + PATH = ( + "/groups/(?P<group_id>[^/]*)/summary" + "(/roles/(?P<role_id>[^/]+))?"
+ "/users/(?P[^/]*)" + ) + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + if len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) + + resp = await self.handler.update_group_summary_user( + group_id, + requester_user_id, + user_id=user_id, + role_id=role_id, + content=content, + ) + + return 200, resp + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = await self.handler.delete_group_summary_user( + group_id, requester_user_id, user_id=user_id, role_id=role_id + ) + + return 200, resp + + +class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): + """Sets whether a group is joinable without an invite or knock""" + + PATH = "/groups/(?P[^/]*)/settings/m.join_policy" + + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.set_group_join_policy( + group_id, requester_user_id, content + ) + + return 200, new_content + + +GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( + FederationGroupsProfileServlet, + FederationGroupsSummaryServlet, + FederationGroupsRoomsServlet, + FederationGroupsUsersServlet, + FederationGroupsInvitedUsersServlet, + FederationGroupsInviteServlet, + FederationGroupsAcceptInviteServlet, + FederationGroupsJoinServlet, + FederationGroupsRemoveUserServlet, + FederationGroupsSummaryRoomsServlet, + FederationGroupsCategoriesServlet, + FederationGroupsCategoryServlet, + FederationGroupsRolesServlet, + FederationGroupsRoleServlet, + FederationGroupsSummaryUsersServlet, + FederationGroupsAddRoomsServlet, + FederationGroupsAddRoomsConfigServlet, + FederationGroupsSettingJoinPolicyServlet, +) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index c83a261918c0..b9b12fbea563 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -17,18 +17,17 @@ """ import logging -from typing import Optional +from typing import List, Optional import attr from synapse.types import JsonDict -from synapse.util.jsonobject import JsonEncodedObject logger = logging.getLogger(__name__) -@attr.s(slots=True) -class Edu(JsonEncodedObject): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Edu: """An Edu represents a piece of data sent from one homeserver to another. 
In comparison to Pdus, Edus are not persisted for a long time on disk, are @@ -36,10 +35,10 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - edu_type = attr.ib(type=str) - content = attr.ib(type=dict) - origin = attr.ib(type=str) - destination = attr.ib(type=str) + edu_type: str + content: dict + origin: str + destination: str def get_dict(self) -> JsonDict: return { @@ -55,14 +54,21 @@ def get_internal_dict(self) -> JsonDict: "destination": self.destination, } - def get_context(self): + def get_context(self) -> str: return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") - def strip_context(self): + def strip_context(self) -> None: getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}" -class Transaction(JsonEncodedObject): +def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]: + if edus is None: + return [] + return edus + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Transaction: """A transaction is a list of Pdus and Edus to be sent to a remote home server with some extra metadata. @@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject): """ - valid_keys = [ - "transaction_id", - "origin", - "destination", - "origin_server_ts", - "previous_ids", - "pdus", - "edus", - ] - - internal_keys = ["transaction_id", "destination"] - - required_keys = [ - "transaction_id", - "origin", - "destination", - "origin_server_ts", - "pdus", - ] - - def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs): - """If we include a list of pdus then we decode then as PDU's - automatically. - """ - - # If there's no EDUs then remove the arg - if "edus" in kwargs and not kwargs["edus"]: - del kwargs["edus"] - - super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs) - - @staticmethod - def create_new(pdus, **kwargs): - """Used to create a new transaction. Will auto fill out - transaction_id and origin_server_ts keys. - """ - if "origin_server_ts" not in kwargs: - raise KeyError("Require 'origin_server_ts' to construct a Transaction") - if "transaction_id" not in kwargs: - raise KeyError("Require 'transaction_id' to construct a Transaction") - - kwargs["pdus"] = [p.get_pdu_json() for p in pdus] - - return Transaction(**kwargs) + # Required keys. 
+ transaction_id: str + origin: str + destination: str + origin_server_ts: int + pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) + edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) + + def get_dict(self) -> JsonDict: + """A JSON-ready dictionary of valid keys which aren't internal.""" + result = { + "origin": self.origin, + "origin_server_ts": self.origin_server_ts, + "pdus": self.pdus, + } + if self.edus: + result["edus"] = self.edus + return result diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 21a17cd2e834..4ab4046650b8 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -392,9 +392,6 @@ async def get_3pe_protocols( protocols[p].append(info) def _merge_instances(infos: List[JsonDict]) -> JsonDict: - if not infos: - return {} - # Merge the 'instances' lists of multiple results, but just take # the other fields from the first as they ought to be identical # copy the result so as not to corrupt the cached one @@ -406,7 +403,9 @@ def _merge_instances(infos: List[JsonDict]) -> JsonDict: return combined - return {p: _merge_instances(protocols[p]) for p in protocols.keys()} + return { + p: _merge_instances(protocols[p]) for p in protocols.keys() if protocols[p] + } async def _get_services_for_event( self, event: EventBase diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 22a855224188..161b3c933c5d 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -73,7 +73,7 @@ from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: - from synapse.rest.client.v1.login import LoginResponse + from synapse.rest.client.login import LoginResponse from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -461,7 +461,7 @@ async def check_ui_auth( If no auth flows have been completed successfully, raises an InteractiveAuthIncompleteError. To handle this, you can use - synapse.rest.client.v2_alpha._base.interactive_auth_handler as a + synapse.rest.client._base.interactive_auth_handler as a decorator. Args: @@ -543,7 +543,7 @@ async def check_ui_auth( # Note that the registration endpoint explicitly removes the # "initial_device_display_name" parameter if it is provided # without a "password" parameter. See the changes to - # synapse.rest.client.v2_alpha.register.RegisterRestServlet.on_POST + # synapse.rest.client.register.RegisterRestServlet.on_POST # in commit 544722bad23fc31056b9240189c3cbbbf0ffd3f9. 
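Returning to the appservice.py hunk above: the comprehension now skips protocols whose result list is empty instead of mapping them to an empty dict. A simplified restatement of the intended behaviour (not the handler's actual code):

    from typing import Any, Dict, List

    JsonDict = Dict[str, Any]

    def merge_instances(infos: List[JsonDict]) -> JsonDict:
        # Take the non-instance fields from the first result and concatenate
        # every result's "instances" list, mirroring _merge_instances.
        combined = dict(infos[0])
        combined["instances"] = list(combined.get("instances", []))
        for info in infos[1:]:
            combined["instances"] += info.get("instances", [])
        return combined

    protocols = {"irc": [{"instances": ["a"]}, {"instances": ["b"]}], "gitter": []}
    merged = {p: merge_instances(infos) for p, infos in protocols.items() if infos}
    assert "gitter" not in merged  # empty results are dropped, not returned as {}
    assert merged["irc"]["instances"] == ["a", "b"]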
if not clientdict: clientdict = session.clientdict diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index e2410e482f8e..4288ffff094a 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -213,7 +213,7 @@ async def check_restricted_join_rules( raise AuthError( 403, - "You do not belong to any of the required rooms to join this room.", + "You do not belong to any of the required rooms/spaces to join this room.", ) async def has_restricted_join_rules( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8197b60b7673..c0e13bdaac1d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -42,6 +42,7 @@ from synapse import event_auth from synapse.api.constants import ( + EventContentFields, EventTypes, Membership, RejectedReason, @@ -108,21 +109,33 @@ ) -@attr.s(slots=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class _NewEventInfo: """Holds information about a received event, ready for passing to _auth_and_persist_events Attributes: event: the received event - state: the state at that event + state: the state at that event, according to /state_ids from a remote + homeserver. Only populated for backfilled events which are going to be a + new backwards extremity. + + claimed_auth_event_map: a map of (type, state_key) => event for the event's + claimed auth_events. + + This can include events which have not yet been persisted, in the case that + we are backfilling a batch of events. + + Note: May be incomplete: if we were unable to find all of the claimed auth + events. Also, treat the contents with caution: the events might also have + been rejected, might not yet have been authorized themselves, or they might + be in the wrong room. - auth_events: the auth_event map for that event """ - event = attr.ib(type=EventBase) - state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None) + event: EventBase + state: Optional[Sequence[EventBase]] + claimed_auth_event_map: StateMap[EventBase] class FederationHandler(BaseHandler): @@ -207,8 +220,6 @@ async def on_receive_pdu( room_id = pdu.room_id event_id = pdu.event_id - logger.info("handling received PDU: %s", pdu) - # We reprocess pdus when we have seen them only as outliers existing = await self.store.get_event( event_id, allow_none=True, allow_rejected=True @@ -216,14 +227,19 @@ async def on_receive_pdu( # FIXME: Currently we fetch an event again when we already have it # if it has been marked as an outlier. - - already_seen = existing and ( - not existing.internal_metadata.is_outlier() - or pdu.internal_metadata.is_outlier() - ) - if already_seen: - logger.debug("Already seen pdu") - return + if existing: + if not existing.internal_metadata.is_outlier(): + logger.info( + "Ignoring received event %s which we have already seen", event_id + ) + return + if pdu.internal_metadata.is_outlier(): + logger.info( + "Ignoring received outlier %s which we already have as an outlier", + event_id, + ) + return + logger.info("De-outliering event %s", event_id) # do some initial sanity-checking of the event. In particular, make # sure it doesn't have hundreds of prev_events or auth_events, which @@ -262,7 +278,12 @@ async def on_receive_pdu( state = None - # Get missing pdus if necessary. + # Check that the event passes auth based on the state at the event. This is + # done for events that are to be added to the timeline (non-outliers). 
+ # + # Get missing pdus if necessary: + # - Fetching any missing prev events to fill in gaps in the graph + # - Fetching state if we have a hole in the graph + if not pdu.internal_metadata.is_outlier(): + # We only backfill backwards to the min depth. + min_depth = await self.get_min_depth_for_context(pdu.room_id) @@ -313,7 +334,8 @@ async def on_receive_pdu( "Found all missing prev_events", ) - if prevs - seen: + missing_prevs = prevs - seen + if missing_prevs: # We've still not been able to get all of the prev_events for this event. # # In this case, we need to fall back to asking another server in the @@ -341,8 +363,8 @@ async def on_receive_pdu( if sent_to_us_directly: logger.warning( "Rejecting: failed to fetch %d prev events: %s", - len(prevs - seen), - shortstr(prevs - seen), + len(missing_prevs), + shortstr(missing_prevs), ) raise FederationError( "ERROR", @@ -355,9 +377,10 @@ async def on_receive_pdu( ) logger.info( - "Event %s is missing prev_events: calculating state for a " + "Event %s is missing prev_events %s: calculating state for a " "backwards extremity", event_id, + shortstr(missing_prevs), ) # Calculate the state after each of the previous events, and @@ -375,7 +398,7 @@ async def on_receive_pdu( # Ask the remote server for the states we don't # know about - for p in prevs - seen: + for p in missing_prevs: logger.info("Requesting state after missing prev_event %s", p) with nested_logging_context(p): @@ -432,6 +455,13 @@ async def on_receive_pdu( affected=event_id, ) + # A second round of checks for all events. Check that the event passes auth + # based on `auth_events`; this allows us to assert that the event would + # have been allowed at some point. If an event passes this check it's OK + # for it to be used as part of a returned `/state` request, as either + # a) we received the event as part of the original join and so trust it, or + # b) we'll do a state resolution with existing state before it becomes + # part of the "current state", which adds more protection. await self._process_received_pdu(origin, pdu, state=state) async def _get_missing_events_for_pdu( @@ -531,21 +561,14 @@ async def _get_missing_events_for_pdu( logger.warning("Failed to get prev_events: %s", e) return - logger.info( - "Got %d prev_events: %s", - len(missing_events), - shortstr(missing_events), - ) + logger.info("Got %d prev_events", len(missing_events)) # We want to sort these by depth so we process them and # tell clients about them in order. missing_events.sort(key=lambda x: x.depth) for ev in missing_events: - logger.info( - "Handling received prev_event %s", - ev.event_id, - ) + logger.info("Handling received prev_event %s", ev) with nested_logging_context(ev.event_id): try: await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) @@ -889,6 +912,79 @@ async def _process_received_pdu( "resync_device_due_to_pdu", self._resync_device, event.sender ) + await self._handle_marker_event(origin, event) + + async def _handle_marker_event(self, origin: str, marker_event: EventBase): + """Handles backfilling the insertion event when we receive a marker + event that points to one. + + Args: + origin: Origin of the event; the server we will call to get the insertion event + marker_event: The event to process + """ + + if marker_event.type != EventTypes.MSC2716_MARKER: + # Not a marker event + return + + if marker_event.rejected_reason is not None: + # Rejected event + return + + # Skip processing a marker event if the room version doesn't + # support it.
+ room_version = await self.store.get_room_version(marker_event.room_id) + if not room_version.msc2716_historical: + return + + logger.debug("_handle_marker_event: received %s", marker_event) + + insertion_event_id = marker_event.content.get( + EventContentFields.MSC2716_MARKER_INSERTION + ) + + if insertion_event_id is None: + # Nothing to retrieve then (invalid marker) + return + + logger.debug( + "_handle_marker_event: backfilling insertion event %s", insertion_event_id + ) + + await self._get_events_and_persist( + origin, + marker_event.room_id, + [insertion_event_id], + ) + + insertion_event = await self.store.get_event( + insertion_event_id, allow_none=True + ) + if insertion_event is None: + logger.warning( + "_handle_marker_event: server %s didn't return insertion event %s for marker %s", + origin, + insertion_event_id, + marker_event.event_id, + ) + return + + logger.debug( + "_handle_marker_event: successfully backfilled insertion event %s from marker event %s", + insertion_event, + marker_event, + ) + + await self.store.insert_insertion_extremity( + insertion_event_id, marker_event.room_id + ) + + logger.debug( + "_handle_marker_event: insertion extremity added for %s from marker event %s", + insertion_event, + marker_event, + ) + async def _resync_device(self, sender: str) -> None: """We have detected that the device list for the given user may be out of sync, so we try and resync them. @@ -1000,7 +1096,7 @@ async def backfill( _NewEventInfo( event=ev, state=events_to_state[e_id], - auth_events={ + claimed_auth_event_map={ ( auth_events[a_id].type, auth_events[a_id].state_key, @@ -1057,9 +1153,19 @@ async def maybe_backfill( async def _maybe_backfill_inner( self, room_id: str, current_depth: int, limit: int ) -> bool: - extremities = await self.store.get_oldest_events_with_depth_in_room(room_id) + oldest_events_with_depth = ( + await self.store.get_oldest_event_ids_with_depth_in_room(room_id) + ) + insertion_events_to_be_backfilled = ( + await self.store.get_insertion_event_backwards_extremities_in_room(room_id) + ) + logger.debug( + "_maybe_backfill_inner: extremities oldest_events_with_depth=%s insertion_events_to_be_backfilled=%s", + oldest_events_with_depth, + insertion_events_to_be_backfilled, + ) - if not extremities: + if not oldest_events_with_depth and not insertion_events_to_be_backfilled: logger.debug("Not backfilling as no extremeties found.") return False @@ -1089,10 +1195,12 @@ async def _maybe_backfill_inner( # state *before* the event, ignoring the special casing certain event # types have. - forward_events = await self.store.get_successor_events(list(extremities)) + forward_event_ids = await self.store.get_successor_events( + list(oldest_events_with_depth) + ) extremities_events = await self.store.get_events( - forward_events, + forward_event_ids, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, ) @@ -1106,10 +1214,19 @@ async def _maybe_backfill_inner( redact=False, check_history_visibility_only=True, ) + logger.debug( + "_maybe_backfill_inner: filtered_extremities %s", filtered_extremities + ) - if not filtered_extremities: + if not filtered_extremities and not insertion_events_to_be_backfilled: return False + extremities = { + **oldest_events_with_depth, + # TODO: insertion_events_to_be_backfilled is currently skipping the filtered_extremities checks + **insertion_events_to_be_backfilled, + } + + # Check if we reached a point where we should start backfilling.
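Both extremity sources above are plain maps of event ID to depth, merged before the depth sort that follows; a trimmed illustration of that step (IDs and depths invented):

    # Insertion-event extremities are merged in unfiltered for now (see the
    # TODO above) and compete on depth like any other backwards extremity.
    oldest_events_with_depth = {"$ev1": 10, "$ev2": 7}
    insertion_events_to_be_backfilled = {"$insertion1": 9}
    extremities = {**oldest_events_with_depth, **insertion_events_to_be_backfilled}
    # Deepest first, as in the sorted(..., key=lambda e: -int(e[1])) below:
    assert sorted(extremities.items(), key=lambda e: -int(e[1]))[0] == ("$ev1", 10)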
sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1])) max_depth = sorted_extremeties_tuple[0][1] @@ -1643,10 +1760,8 @@ async def _handle_queued_pdus( for p, origin in room_queue: try: logger.info( - "Processing queued PDU %s which was received " - "while we were joining %s", - p.event_id, - p.room_id, + "Processing queued PDU %s which was received while we were joining", + p, ) with nested_logging_context(p.event_id): await self.on_receive_pdu(origin, p, sent_to_us_directly=True) @@ -2208,7 +2323,7 @@ async def _auth_and_persist_event( event: EventBase, context: EventContext, state: Optional[Iterable[EventBase]] = None, - auth_events: Optional[MutableStateMap[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, backfilled: bool = False, ) -> None: """ @@ -2220,17 +2335,18 @@ async def _auth_and_persist_event( context: The event context. - NB that this function potentially modifies it. state: The state events used to check the event for soft-fail. If this is not provided the current state events will be used. - auth_events: - Map from (event_type, state_key) to event - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. + + Only populated where we may not already have persisted these events - + for example, when populating outliers. + backfilled: True if the event was backfilled. """ context = await self._check_event_auth( @@ -2238,7 +2354,7 @@ async def _auth_and_persist_event( event, context, state=state, - auth_events=auth_events, + claimed_auth_event_map=claimed_auth_event_map, backfilled=backfilled, ) @@ -2302,7 +2418,7 @@ async def prep(ev_info: _NewEventInfo): event, res, state=ev_info.state, - auth_events=ev_info.auth_events, + claimed_auth_event_map=ev_info.claimed_auth_event_map, backfilled=backfilled, ) return res @@ -2568,7 +2684,7 @@ async def _check_event_auth( event: EventBase, context: EventContext, state: Optional[Iterable[EventBase]] = None, - auth_events: Optional[MutableStateMap[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, backfilled: bool = False, ) -> EventContext: """ @@ -2580,21 +2696,19 @@ async def _check_event_auth( context: The event context. - NB that this function potentially modifies it. state: The state events used to check the event for soft-fail. If this is not provided the current state events will be used. - auth_events: - Map from (event_type, state_key) to event - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. - Also NB that this function adds entries to it. + Only populated where we may not already have persisted these events - + for example, when populating outliers, or the state for a backwards + extremity. - If this is not provided, it is calculated from the previous state IDs. 
backfilled: True if the event was backfilled. Returns: @@ -2603,7 +2717,12 @@ async def _check_event_auth( room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if not auth_events: + if claimed_auth_event_map: + # if we have a copy of the auth events from the event, use that as the + # basis for auth. + auth_events = claimed_auth_event_map + else: + # otherwise, we calculate what the auth events *should* be, and use that prev_state_ids = await context.get_prev_state_ids() auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True @@ -2611,18 +2730,11 @@ async def _check_event_auth( auth_events_x = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} - # This is a hack to fix some old rooms where the initial join event - # didn't reference the create event in its auth events. - if event.type == EventTypes.Member and not event.auth_event_ids(): - if len(event.prev_event_ids()) == 1 and event.depth < 5: - c = await self.store.get_event( - event.prev_event_ids()[0], allow_none=True - ) - if c and c.type == EventTypes.Create: - auth_events[(c.type, c.state_key)] = c - try: - context = await self._update_auth_events_and_context_for_auth( + ( + context, + auth_events_for_auth, + ) = await self._update_auth_events_and_context_for_auth( origin, event, context, auth_events ) except Exception: @@ -2635,9 +2747,10 @@ async def _check_event_auth( "Ignoring failure and continuing processing of event.", event.event_id, ) + auth_events_for_auth = auth_events try: - event_auth.check(room_version_obj, event, auth_events=auth_events) + event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR @@ -2662,8 +2775,8 @@ async def _update_auth_events_and_context_for_auth( origin: str, event: EventBase, context: EventContext, - auth_events: MutableStateMap[EventBase], - ) -> EventContext: + input_auth_events: StateMap[EventBase], + ) -> Tuple[EventContext, StateMap[EventBase]]: """Helper for _check_event_auth. See there for docs. Checks whether a given event has the expected auth events. If it @@ -2680,7 +2793,7 @@ async def _update_auth_events_and_context_for_auth( event: context: - auth_events: + input_auth_events: Map from (event_type, state_key) to event Normally, our calculated auth_events based on the state of the room @@ -2688,11 +2801,12 @@ async def _update_auth_events_and_context_for_auth( event is an outlier), may be the auth events claimed by the remote server. - Also NB that this function adds entries to it. - Returns: - updated context + updated context, updated auth event map """ + # take a copy of input_auth_events before we modify it. + auth_events: MutableStateMap[EventBase] = dict(input_auth_events) + event_auth_events = set(event.auth_event_ids()) # missing_auth is the set of the event's auth_events which we don't yet have @@ -2721,7 +2835,7 @@ async def _update_auth_events_and_context_for_auth( # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. 
logger.info("Failed to get event auth from remote: %s", e1) - return context + return context, auth_events seen_remotes = await self.store.have_seen_events( event.room_id, [e.event_id for e in remote_auth_chain] @@ -2752,7 +2866,10 @@ async def _update_auth_events_and_context_for_auth( await self.state_handler.compute_event_context(e) ) await self._auth_and_persist_event( - origin, e, missing_auth_event_context, auth_events=auth + origin, + e, + missing_auth_event_context, + claimed_auth_event_map=auth, ) if e.event_id in event_auth_events: @@ -2770,14 +2887,14 @@ async def _update_auth_events_and_context_for_auth( # obviously be empty # (b) alternatively, why don't we do it earlier? logger.info("Skipping auth_event fetch for outlier") - return context + return context, auth_events different_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) if not different_auth: - return context + return context, auth_events logger.info( "auth_events refers to events which are not in our calculated auth " @@ -2803,7 +2920,7 @@ async def _update_auth_events_and_context_for_auth( # XXX: should we reject the event in this case? It feels like we should, # but then shouldn't we also do so if we've failed to fetch any of the # auth events? - return context + return context, auth_events # now we state-resolve between our own idea of the auth events, and the remote's # idea of them. @@ -2833,7 +2950,7 @@ async def _update_auth_events_and_context_for_auth( event, context, auth_events ) - return context + return context, auth_events async def _update_context_for_auth_events( self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 0961dec5ab5c..8ffeabacf975 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -824,6 +824,7 @@ async def ask_id_server_for_third_party_invite( room_avatar_url: str, room_join_rules: str, room_name: str, + room_type: Optional[str], inviter_display_name: str, inviter_avatar_url: str, id_access_token: Optional[str] = None, @@ -843,6 +844,7 @@ async def ask_id_server_for_third_party_invite( notifications. room_join_rules: The join rules of the email (e.g. "public"). room_name: The m.room.name of the room. + room_type: The type of the room from its m.room.create event (e.g "m.space"). inviter_display_name: The current display name of the inviter. inviter_avatar_url: The URL of the inviter's avatar. @@ -869,6 +871,10 @@ async def ask_id_server_for_third_party_invite( "sender_display_name": inviter_display_name, "sender_avatar_url": inviter_avatar_url, } + + if room_type is not None: + invite_config["org.matrix.msc3288.room_type"] = room_type + # If a custom web client location is available, include it in the request. 
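To make the identity-server change above concrete: the invite payload now carries the room type under the unstable MSC3288 key. A trimmed sketch of the resulting config (field values invented):

    # Trimmed sketch of invite_config after the change above.
    room_type = "m.space"  # read from the m.room.create event's content
    invite_config = {
        "sender": "@inviter:example.org",
        "room_id": "!room:example.org",
        "room_name": "Team space",
        "sender_display_name": "Inviter",
    }
    if room_type is not None:
        invite_config["org.matrix.msc3288.room_type"] = room_type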
if self._web_client_location: invite_config["org.matrix.web_client_location"] = self._web_client_location diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 016c5df2ca75..7ca14e1d8473 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1184,8 +1184,7 @@ async def set_state( new_fields = {"state": presence} if not ignore_status_msg: - msg = status_msg if presence != PresenceState.OFFLINE else None - new_fields["status_msg"] = msg + new_fields["status_msg"] = status_msg if presence == PresenceState.ONLINE or ( presence == PresenceState.BUSY and self._busy_presence_enabled @@ -1478,7 +1477,7 @@ def format_user_presence_state( content["user_id"] = state.user_id if state.last_active_ts: content["last_active_ago"] = now - state.last_active_ts - if state.status_msg and state.state != PresenceState.OFFLINE: + if state.status_msg: content["status_msg"] = state.status_msg if state.state == PresenceState.ONLINE: content["currently_active"] = state.currently_active @@ -1840,9 +1839,7 @@ def handle_timeout( # don't set them as offline. sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) if now - sync_or_active > SYNC_ONLINE_TIMEOUT: - state = state.copy_and_replace( - state=PresenceState.OFFLINE, status_msg=None - ) + state = state.copy_and_replace(state=PresenceState.OFFLINE) changed = True else: # We expect to be poked occasionally by the other side. @@ -1850,7 +1847,7 @@ def handle_timeout( # no one gets stuck online forever. if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: # The other side seems to have disappeared. - state = state.copy_and_replace(state=PresenceState.OFFLINE, status_msg=None) + state = state.copy_and_replace(state=PresenceState.OFFLINE) changed = True return state if changed else None diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 21cca94f5d4d..fb495229a7e1 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -70,7 +70,8 @@ async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None ) if not is_in_room: logger.info( - "Ignoring receipt from %s as we're not in the room", + "Ignoring receipt for room %r from server %s as we're not in the room", + room_id, origin, ) continue @@ -188,21 +189,21 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]: new_users = {} for rr_user_id, user_rr in m_read.items(): try: - hidden = user_rr.get("hidden", None) - if hidden is not True or rr_user_id == user_id: - new_users[rr_user_id] = user_rr.copy() - # If hidden has a value replace hidden with the correct prefixed key - if hidden is not None: - new_users[rr_user_id].pop("hidden") - new_users[rr_user_id][ - ReadReceiptEventFields.MSC2285_HIDDEN - ] = hidden + hidden = user_rr.get("hidden") except AttributeError: # Due to https://github.com/matrix-org/synapse/issues/10376 # there are cases where user_rr is a string, in those cases - # we just copy the user_rr - new_users[rr_user_id] = user_rr - pass + # we just ignore the read receipt + continue + + if hidden is not True or rr_user_id == user_id: + new_users[rr_user_id] = user_rr.copy() + # If hidden has a value replace hidden with the correct prefixed key + if hidden is not None: + new_users[rr_user_id].pop("hidden") + new_users[rr_user_id][ + ReadReceiptEventFields.MSC2285_HIDDEN + ] = hidden # Set new users unless empty if len(new_users.keys()) > 0: diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 
fae2c098e32e..6d433fad41b0 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -356,6 +356,12 @@ async def get_remote_public_room_list( include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, ) -> JsonDict: + """Get the public room list from remote server + + Raises: + SynapseError + """ + if not self.enable_room_list_search: return {"chunk": [], "total_room_count_estimate": 0} @@ -395,13 +401,16 @@ async def get_remote_public_room_list( limit = None since_token = None - res = await self._get_remote_list_cached( - server_name, - limit=limit, - since_token=since_token, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) + try: + res = await self._get_remote_list_cached( + server_name, + limit=limit, + since_token=since_token, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + except (RequestSendFailed, HttpResponseException): + raise SynapseError(502, "Failed to fetch room list") if search_filter: res = { @@ -423,20 +432,21 @@ async def _get_remote_list_cached( include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, ) -> JsonDict: + """Wrapper around FederationClient.get_public_rooms that caches the + result. + """ + repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search - try: - return await repl_layer.get_public_rooms( - server_name, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - except (RequestSendFailed, HttpResponseException): - raise SynapseError(502, "Failed to fetch room list") + return await repl_layer.get_public_rooms( + server_name, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) key = ( server_name, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 65ad3efa6a60..ba131962185f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -19,7 +19,12 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple from synapse import types -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + EventContentFields, + EventTypes, + Membership, +) from synapse.api.errors import ( AuthError, Codes, @@ -1237,6 +1242,11 @@ async def _make_and_store_3pid_invite( if room_name_event: room_name = room_name_event.content.get("name", "") + room_type = None + room_create_event = room_state.get((EventTypes.Create, "")) + if room_create_event: + room_type = room_create_event.content.get(EventContentFields.ROOM_TYPE) + room_join_rules = "" join_rules_event = room_state.get((EventTypes.JoinRules, "")) if join_rules_event: @@ -1263,6 +1273,7 @@ async def _make_and_store_3pid_invite( room_avatar_url=room_avatar_url, room_join_rules=room_join_rules, room_name=room_name, + room_type=room_type, inviter_display_name=inviter_display_name, inviter_avatar_url=inviter_avatar_url, id_access_token=id_access_token, diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py new file mode 100644 index 000000000000..ac6cfc0da915 --- /dev/null +++ b/synapse/handlers/room_summary.py @@ -0,0 +1,1171 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. 
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py new file mode 100644 index 000000000000..ac6cfc0da915 --- /dev/null +++ b/synapse/handlers/room_summary.py @@ -0,0 +1,1171 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging +import re +from collections import deque +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple + +import attr + +from synapse.api.constants import ( + EventContentFields, + EventTypes, + HistoryVisibility, + JoinRules, + Membership, + RoomTypes, +) +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.events import EventBase +from synapse.events.utils import format_event_for_client_v2 +from synapse.types import JsonDict +from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# number of rooms to return. We'll stop once we hit this limit. +MAX_ROOMS = 50 + +# max number of events to return per room. +MAX_ROOMS_PER_SPACE = 50 + +# max number of federation servers to hit per room +MAX_SERVERS_PER_SPACE = 3 + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PaginationKey: + """The key used to find a unique pagination session.""" + + # The first three entries match the request parameters (and cannot change + # during a pagination session). + room_id: str + suggested_only: bool + max_depth: Optional[int] + # The randomly generated token. + token: str + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PaginationSession: + """The information that is stored for pagination.""" + + # The time the pagination session was created, in milliseconds. + creation_time_ms: int + # The queue of rooms which are still to be processed. + room_queue: List["_RoomQueueEntry"] + # A set of rooms which have been processed. + processed_rooms: Set[str] + + +class RoomSummaryHandler: + # The time a pagination session remains valid for. + _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000 + + def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() + self._event_auth_handler = hs.get_event_auth_handler() + self._store = hs.get_datastore() + self._event_serializer = hs.get_event_client_serializer() + self._server_name = hs.hostname + self._federation_client = hs.get_federation_client() + + # A map of query information to the current pagination state. + # + # TODO Allow for multiple workers to share this data. + # TODO Expire pagination tokens. + self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {}
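Because _PaginationKey is declared frozen=True, attrs generates __hash__ and __eq__ for it, so a key rebuilt from the same request parameters and token finds the same stored session. A small demonstration with an analogous class:

import attr
from typing import Optional

@attr.s(slots=True, frozen=True, auto_attribs=True)
class PaginationKey:
    room_id: str
    suggested_only: bool
    max_depth: Optional[int]
    token: str

sessions = {PaginationKey("!space:example.org", False, None, "tok1"): "state"}

# An equal key (same fields) hashes identically and finds the same entry.
assert sessions[PaginationKey("!space:example.org", False, None, "tok1")] == "state"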
+ # If a user tries to fetch the same page multiple times in quick succession, + # only process the first attempt and return its result to subsequent requests. + self._pagination_response_cache: ResponseCache[ + Tuple[str, bool, Optional[int], Optional[int], Optional[str]] + ] = ResponseCache( + hs.get_clock(), + "get_room_hierarchy", + ) + + def _expire_pagination_sessions(self): + """Expire pagination sessions which are old.""" + expire_before = ( + self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS + ) + to_expire = [] + + for key, value in self._pagination_sessions.items(): + if value.creation_time_ms < expire_before: + to_expire.append(key) + + for key in to_expire: + logger.debug("Expiring pagination session id %s", key) + del self._pagination_sessions[key] + + async def get_space_summary( + self, + requester: str, + room_id: str, + suggested_only: bool = False, + max_rooms_per_space: Optional[int] = None, + ) -> JsonDict: + """ + Implementation of the space summary C-S API + + Args: + requester: user id of the user making this request + + room_id: room id to start the summary at + + suggested_only: whether we should only return children with the "suggested" + flag set. + + max_rooms_per_space: an optional limit on the number of child rooms we will + return. This does not apply to the root room (ie, room_id), and + is overridden by MAX_ROOMS_PER_SPACE. + + Returns: + summary dict to return + """ + # First of all, check that the room is accessible. + if not await self._is_local_room_accessible(room_id, requester): + raise AuthError( + 403, + "User %s not in room %s, and room previews are disabled" + % (requester, room_id), + ) + + # the queue of rooms to process + room_queue = deque((_RoomQueueEntry(room_id, ()),)) + + # rooms we have already processed + processed_rooms: Set[str] = set() + + # events we have already processed. We don't necessarily have their event ids, + # so instead we key on (room id, state key) + processed_events: Set[Tuple[str, str]] = set() + + rooms_result: List[JsonDict] = [] + events_result: List[JsonDict] = [] + + while room_queue and len(rooms_result) < MAX_ROOMS: + queue_entry = room_queue.popleft() + room_id = queue_entry.room_id + if room_id in processed_rooms: + # already done this room + continue + + logger.debug("Processing room %s", room_id) + + is_in_room = await self._store.is_host_joined(room_id, self._server_name) + + # The client-specified max_rooms_per_space limit doesn't apply to the + # room_id specified in the request, so we ignore it if this is the + # first room we are processing. + max_children = max_rooms_per_space if processed_rooms else None + + if is_in_room: + room_entry = await self._summarize_local_room( + requester, None, room_id, suggested_only, max_children + ) + + events: Sequence[JsonDict] = [] + if room_entry: + rooms_result.append(room_entry.room) + events = room_entry.children_state_events + + logger.debug( + "Query of local room %s returned events %s", + room_id, + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], + ) + else: + fed_rooms = await self._summarize_remote_room( + queue_entry, + suggested_only, + max_children, + exclude_rooms=processed_rooms, + ) + + # The results over federation might include rooms that we, + # as the requesting server, are allowed to see, but the requesting + # user is not permitted to see. + # + # Filter the returned results to only what is accessible to the user. + events = [] + for room_entry in fed_rooms: + room = room_entry.room + fed_room_id = room_entry.room_id + + # The user can see the room, include it!
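The space-summary walk above is a bounded breadth-first traversal: a deque of pending rooms, a processed set to break cycles (a room can be a child of several spaces), and a hard cap of MAX_ROOMS results. The same loop shape in miniature, with a toy children lookup:

from collections import deque
from typing import Dict, List, Set

def walk(root: str, children: Dict[str, List[str]], max_rooms: int = 50) -> List[str]:
    queue = deque([root])
    processed: Set[str] = set()
    result: List[str] = []
    while queue and len(result) < max_rooms:
        room_id = queue.popleft()
        if room_id in processed:
            continue  # already visited via another parent
        processed.add(room_id)
        result.append(room_id)
        queue.extend(children.get(room_id, []))
    return result

print(walk("!space", {"!space": ["!a", "!b"], "!a": ["!b"]}))  # ['!space', '!a', '!b']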
+ if await self._is_remote_room_accessible( + requester, fed_room_id, room + ): + # Before returning to the client, remove the allowed_room_ids + # and allowed_spaces keys. + room.pop("allowed_room_ids", None) + room.pop("allowed_spaces", None) + + rooms_result.append(room) + events.extend(room_entry.children_state_events) + + # All rooms returned don't need visiting again (even if the user + # didn't have access to them). + processed_rooms.add(fed_room_id) + + logger.debug( + "Query of %s returned rooms %s, events %s", + room_id, + [room_entry.room.get("room_id") for room_entry in fed_rooms], + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], + ) + + # the room we queried may or may not have been returned, but don't process + # it again, anyway. + processed_rooms.add(room_id) + + # XXX: is it ok that we blindly iterate through any events returned by + # a remote server, whether or not they actually link to any rooms in our + # tree? + for ev in events: + # remote servers might return events we have already processed + # (eg, Dendrite returns inward pointers as well as outward ones), so + # we need to filter them out, to avoid returning duplicate links to the + # client. + ev_key = (ev["room_id"], ev["state_key"]) + if ev_key in processed_events: + continue + events_result.append(ev) + + # add the child to the queue. we have already validated + # that the vias are a list of server names. + room_queue.append( + _RoomQueueEntry(ev["state_key"], ev["content"]["via"]) + ) + processed_events.add(ev_key) + + return {"rooms": rooms_result, "events": events_result} + + async def get_room_hierarchy( + self, + requester: str, + requested_room_id: str, + suggested_only: bool = False, + max_depth: Optional[int] = None, + limit: Optional[int] = None, + from_token: Optional[str] = None, + ) -> JsonDict: + """ + Implementation of the room hierarchy C-S API. + + Args: + requester: The user ID of the user making this request. + requested_room_id: The room ID to start the hierarchy at (the "root" room). + suggested_only: Whether we should only return children with the "suggested" + flag set. + max_depth: The maximum depth in the tree to explore, must be a + non-negative integer. + + 0 would correspond to just the root room, 1 would include just + the root room's children, etc. + limit: An optional limit on the number of rooms to return per + page. Must be a positive integer. + from_token: An optional pagination token. + + Returns: + The JSON hierarchy dictionary. + """ + # If a user tries to fetch the same page multiple times in quick succession, + # only process the first attempt and return its result to subsequent requests. + # + # This is due to the pagination process mutating internal state, attempting + # to process multiple requests for the same page will result in errors. + return await self._pagination_response_cache.wrap( + (requested_room_id, suggested_only, max_depth, limit, from_token), + self._get_room_hierarchy, + requester, + requested_room_id, + suggested_only, + max_depth, + limit, + from_token, + ) + + async def _get_room_hierarchy( + self, + requester: str, + requested_room_id: str, + suggested_only: bool = False, + max_depth: Optional[int] = None, + limit: Optional[int] = None, + from_token: Optional[str] = None, + ) -> JsonDict: + """See docstring for SpaceSummaryHandler.get_room_hierarchy.""" + + # First of all, check that the room is accessible. 
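The ResponseCache wrap in get_room_hierarchy exists because paginating mutates handler state: two concurrent requests for the same page must share one computation rather than race. An asyncio analogue of that request-coalescing idea (Synapse's own ResponseCache is Twisted-based and supports timeouts):

import asyncio
from typing import Any, Awaitable, Callable, Dict, Hashable

class MiniResponseCache:
    """Coalesce concurrent calls that share a key onto one in-flight task."""

    def __init__(self) -> None:
        self._pending: Dict[Hashable, "asyncio.Task[Any]"] = {}

    async def wrap(self, key: Hashable, factory: Callable[[], Awaitable[Any]]) -> Any:
        task = self._pending.get(key)
        if task is None:
            task = asyncio.ensure_future(factory())
            self._pending[key] = task
            # Drop the entry once finished so later calls recompute.
            task.add_done_callback(lambda _: self._pending.pop(key, None))
        # shield() stops one cancelled caller from killing the shared task.
        return await asyncio.shield(task)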
+ if not await self._is_local_room_accessible(requested_room_id, requester): + raise AuthError( + 403, + "User %s not in room %s, and room previews are disabled" + % (requester, requested_room_id), + ) + + # If this is continuing a previous session, pull the persisted data. + if from_token: + self._expire_pagination_sessions() + + pagination_key = _PaginationKey( + requested_room_id, suggested_only, max_depth, from_token + ) + if pagination_key not in self._pagination_sessions: + raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) + + # Load the previous state. + pagination_session = self._pagination_sessions[pagination_key] + room_queue = pagination_session.room_queue + processed_rooms = pagination_session.processed_rooms + else: + # The queue of rooms to process, the next room is last on the stack. + room_queue = [_RoomQueueEntry(requested_room_id, ())] + + # Rooms we have already processed. + processed_rooms = set() + + rooms_result: List[JsonDict] = [] + + # Cap the limit to a server-side maximum. + if limit is None: + limit = MAX_ROOMS + else: + limit = min(limit, MAX_ROOMS) + + # Iterate through the queue until we reach the limit or run out of + # rooms to include. + while room_queue and len(rooms_result) < limit: + queue_entry = room_queue.pop() + room_id = queue_entry.room_id + current_depth = queue_entry.depth + if room_id in processed_rooms: + # already done this room + continue + + logger.debug("Processing room %s", room_id) + + # A map of summaries for children rooms that might be returned over + # federation. The rationale for caching these and *maybe* using them + # is to prefer any information local to the homeserver before trusting + # data received over federation. + children_room_entries: Dict[str, JsonDict] = {} + # A set of room IDs which are children that did not have information + # returned over federation and are known to be inaccessible to the + # current server. We should not reach out over federation to try to + # summarise these rooms. + inaccessible_children: Set[str] = set() + + # If the room is known locally, summarise it! + is_in_room = await self._store.is_host_joined(room_id, self._server_name) + if is_in_room: + room_entry = await self._summarize_local_room( + requester, + None, + room_id, + suggested_only, + # TODO Handle max children. + max_children=None, + ) + + # Otherwise, attempt to use information for federation. + else: + # A previous call might have included information for this room. + # It can be used if either: + # + # 1. The room is not a space. + # 2. The maximum depth has been achieved (since no children + # information is needed). + if queue_entry.remote_room and ( + queue_entry.remote_room.get("room_type") != RoomTypes.SPACE + or (max_depth is not None and current_depth >= max_depth) + ): + room_entry = _RoomEntry( + queue_entry.room_id, queue_entry.remote_room + ) + + # If the above isn't true, attempt to fetch the room + # information over federation. + else: + ( + room_entry, + children_room_entries, + inaccessible_children, + ) = await self._summarize_remote_room_hierarchy( + queue_entry, + suggested_only, + ) + + # Ensure this room is accessible to the requester (and not just + # the homeserver). + if room_entry and not await self._is_remote_room_accessible( + requester, queue_entry.room_id, room_entry.room + ): + room_entry = None + + # This room has been processed and should be ignored if it appears + # elsewhere in the hierarchy. 
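Resuming from from_token is a dictionary lookup keyed on the original request parameters, and the page size is clamped server-side. The essential checks, reduced to a few lines:

from typing import Optional, Tuple

MAX_ROOMS = 50

def resume(sessions: dict, key: object, limit: Optional[int]) -> Tuple[object, int]:
    if key not in sessions:
        # Synapse raises SynapseError(400, "Unknown pagination token", ...) here.
        raise ValueError("Unknown pagination token")
    # Cap the client-requested limit to the server-side maximum.
    effective_limit = MAX_ROOMS if limit is None else min(limit, MAX_ROOMS)
    return sessions[key], effective_limit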
+ processed_rooms.add(room_id) + + # There may or may not be a room entry based on whether it is + # inaccessible to the requesting user. + if room_entry: + # Add the room (including the stripped m.space.child events). + rooms_result.append(room_entry.as_json()) + + # If this room is not at the max-depth, check if there are any + # children to process. + if max_depth is None or current_depth < max_depth: + # The children get added in reverse order so that the next + # room to process, according to the ordering, is the last + # item in the list. + room_queue.extend( + _RoomQueueEntry( + ev["state_key"], + ev["content"]["via"], + current_depth + 1, + children_room_entries.get(ev["state_key"]), + ) + for ev in reversed(room_entry.children_state_events) + if ev["type"] == EventTypes.SpaceChild + and ev["state_key"] not in inaccessible_children + ) + + result: JsonDict = {"rooms": rooms_result} + + # If there's additional data, generate a pagination token (and persist state). + if room_queue: + next_batch = random_string(24) + result["next_batch"] = next_batch + pagination_key = _PaginationKey( + requested_room_id, suggested_only, max_depth, next_batch + ) + self._pagination_sessions[pagination_key] = _PaginationSession( + self._clock.time_msec(), room_queue, processed_rooms + ) + + return result + + async def federation_space_summary( + self, + origin: str, + room_id: str, + suggested_only: bool, + max_rooms_per_space: Optional[int], + exclude_rooms: Iterable[str], + ) -> JsonDict: + """ + Implementation of the space summary Federation API + + Args: + origin: The server requesting the spaces summary. + + room_id: room id to start the summary at + + suggested_only: whether we should only return children with the "suggested" + flag set. + + max_rooms_per_space: an optional limit on the number of child rooms we will + return. Unlike the C-S API, this applies to the root room (room_id). + It is clipped to MAX_ROOMS_PER_SPACE. + + exclude_rooms: a list of rooms to skip over (presumably because the + calling server has already seen them). + + Returns: + summary dict to return + """ + # the queue of rooms to process + room_queue = deque((room_id,)) + + # the set of rooms that we should not walk further. Initialise it with the + # excluded-rooms list; we will add other rooms as we process them so that + # we do not loop. + processed_rooms: Set[str] = set(exclude_rooms) + + rooms_result: List[JsonDict] = [] + events_result: List[JsonDict] = [] + + while room_queue and len(rooms_result) < MAX_ROOMS: + room_id = room_queue.popleft() + if room_id in processed_rooms: + # already done this room + continue + + room_entry = await self._summarize_local_room( + None, origin, room_id, suggested_only, max_rooms_per_space + ) + + processed_rooms.add(room_id) + + if room_entry: + rooms_result.append(room_entry.room) + events_result.extend(room_entry.children_state_events) + + # add any children to the queue + room_queue.extend( + edge_event["state_key"] + for edge_event in room_entry.children_state_events + ) + + return {"rooms": rooms_result, "events": events_result} + + async def get_federation_hierarchy( + self, + origin: str, + requested_room_id: str, + suggested_only: bool, + ): + """ + Implementation of the room hierarchy Federation API. + + This is similar to get_room_hierarchy, but does not recurse into the space. + It also considers whether anyone on the server may be able to access the + room, as opposed to whether a specific user can. + + Args: + origin: The server requesting the spaces summary. 
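Back in _get_room_hierarchy, the room queue is used as a stack: pop() takes from the end, so children are pushed with reversed(...) to keep the first child (by the m.space.child ordering) on top. A toy run showing the effect:

queue = ["!root"]
children = {"!root": ["!a", "!b", "!c"]}

visited = []
while queue:
    room = queue.pop()  # take from the end: a stack
    visited.append(room)
    # Push children reversed so the first-ordered child is popped next.
    queue.extend(reversed(children.get(room, [])))

print(visited)  # ['!root', '!a', '!b', '!c']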
+ requested_room_id: The room ID to start the hierarchy at (the "root" room). + suggested_only: whether we should only return children with the "suggested" + flag set. + + Returns: + The JSON hierarchy dictionary. + """ + root_room_entry = await self._summarize_local_room( + None, origin, requested_room_id, suggested_only, max_children=None + ) + if root_room_entry is None: + # Room is inaccessible to the requesting server. + raise SynapseError(404, "Unknown room: %s" % (requested_room_id,)) + + children_rooms_result: List[JsonDict] = [] + inaccessible_children: List[str] = [] + + # Iterate through each child and potentially add it, but not its children, + # to the response. + for child_room in root_room_entry.children_state_events: + room_id = child_room.get("state_key") + assert isinstance(room_id, str) + # If the room is unknown, skip it. + if not await self._store.is_host_joined(room_id, self._server_name): + continue + + room_entry = await self._summarize_local_room( + None, origin, room_id, suggested_only, max_children=0 + ) + # If the room is accessible, include it in the results. + # + # Note that only the room summary (without information on children) + # is included in the summary. + if room_entry: + children_rooms_result.append(room_entry.room) + # Otherwise, note that the requesting server shouldn't bother + # trying to summarize this room - they do not have access to it. + else: + inaccessible_children.append(room_id) + + return { + # Include the requested room (including the stripped children events). + "room": root_room_entry.as_json(), + "children": children_rooms_result, + "inaccessible_children": inaccessible_children, + } + + async def _summarize_local_room( + self, + requester: Optional[str], + origin: Optional[str], + room_id: str, + suggested_only: bool, + max_children: Optional[int], + ) -> Optional["_RoomEntry"]: + """ + Generate a room entry and a list of event entries for a given room. + + Args: + requester: + The user requesting the summary, if it is a local request. None + if this is a federation request. + origin: + The server requesting the summary, if it is a federation request. + None if this is a local request. + room_id: The room ID to summarize. + suggested_only: True if only suggested children should be returned. + Otherwise, all children are returned. + max_children: + The maximum number of children rooms to include. This is capped + to a server-set limit. + + Returns: + A room entry if the room should be returned. None, otherwise. + """ + if not await self._is_local_room_accessible(room_id, requester, origin): + return None + + room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) + + # If the room is not a space or the children don't matter, return just + # the room information. + if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0: + return _RoomEntry(room_id, room_entry) + + # Otherwise, look for child rooms/spaces. 
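get_federation_hierarchy only descends one level: each child is either summarised and returned, or reported in inaccessible_children so the calling server knows not to retry it through this server. The partitioning logic, abstracted into a sketch:

from typing import Callable, Iterable, List, Optional, Tuple

def partition_children(
    child_ids: Iterable[str],
    summarize: Callable[[str], Optional[dict]],
) -> Tuple[List[dict], List[str]]:
    accessible: List[dict] = []
    inaccessible: List[str] = []
    for room_id in child_ids:
        entry = summarize(room_id)
        if entry is not None:
            accessible.append(entry)  # include the summary, but not its children
        else:
            inaccessible.append(room_id)  # tell the caller not to ask again
    return accessible, inaccessible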
+ child_events = await self._get_child_events(room_id) + + if suggested_only: + # we only care about suggested children + child_events = filter(_is_suggested_child_event, child_events) + + if max_children is None or max_children > MAX_ROOMS_PER_SPACE: + max_children = MAX_ROOMS_PER_SPACE + + now = self._clock.time_msec() + events_result: List[JsonDict] = [] + for edge_event in itertools.islice(child_events, max_children): + events_result.append( + await self._event_serializer.serialize_event( + edge_event, + time_now=now, + event_format=format_event_for_client_v2, + ) + ) + + return _RoomEntry(room_id, room_entry, events_result) + + async def _summarize_remote_room( + self, + room: "_RoomQueueEntry", + suggested_only: bool, + max_children: Optional[int], + exclude_rooms: Iterable[str], + ) -> Iterable["_RoomEntry"]: + """ + Request room entries and a list of event entries for a given room by querying a remote server. + + Args: + room: The room to summarize. + suggested_only: True if only suggested children should be returned. + Otherwise, all children are returned. + max_children: + The maximum number of children rooms to include. This is capped + to a server-set limit. + exclude_rooms: + Room IDs which do not need to be summarized. + + Returns: + An iterable of room entries. + """ + room_id = room.room_id + logger.info("Requesting summary for %s via %s", room_id, room.via) + + # we need to make the exclusion list json-serialisable + exclude_rooms = list(exclude_rooms) + + via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) + try: + res = await self._federation_client.get_space_summary( + via, + room_id, + suggested_only=suggested_only, + max_rooms_per_space=max_children, + exclude_rooms=exclude_rooms, + ) + except Exception as e: + logger.warning( + "Unable to get summary of %s via federation: %s", + room_id, + e, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return () + + # Group the events by their room. + children_by_room: Dict[str, List[JsonDict]] = {} + for ev in res.events: + if ev.event_type == EventTypes.SpaceChild: + children_by_room.setdefault(ev.room_id, []).append(ev.data) + + # Generate the final results. + results = [] + for fed_room in res.rooms: + fed_room_id = fed_room.get("room_id") + if not fed_room_id or not isinstance(fed_room_id, str): + continue + + results.append( + _RoomEntry( + fed_room_id, + fed_room, + children_by_room.get(fed_room_id, []), + ) + ) + + return results + + async def _summarize_remote_room_hierarchy( + self, room: "_RoomQueueEntry", suggested_only: bool + ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]: + """ + Request room entries and a list of event entries for a given room by querying a remote server. + + Args: + room: The room to summarize. + suggested_only: True if only suggested children should be returned. + Otherwise, all children are returned. + + Returns: + A tuple of: + The room entry. + Partial room data returned over federation. + A set of inaccessible children room IDs.
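Two idioms carry most of the weight in these federation helpers: itertools.islice caps how many child events are serialised without materialising the whole iterable, and dict.setdefault buckets returned events by room. Both in a few lines:

import itertools
from typing import Dict, List

events = [
    {"room_id": "!a", "state_key": "!x"},
    {"room_id": "!a", "state_key": "!y"},
    {"room_id": "!b", "state_key": "!z"},
]

# Bucket child events by the room they belong to.
children_by_room: Dict[str, List[dict]] = {}
for ev in events:
    children_by_room.setdefault(ev["room_id"], []).append(ev)

# Cap the number of children taken from any one room.
MAX_ROOMS_PER_SPACE = 50
capped = list(itertools.islice(children_by_room["!a"], MAX_ROOMS_PER_SPACE))
print(sorted(children_by_room), len(capped))  # ['!a', '!b'] 2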
+ """ + room_id = room.room_id + logger.info("Requesting summary for %s via %s", room_id, room.via) + + via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) + try: + ( + room_response, + children, + inaccessible_children, + ) = await self._federation_client.get_room_hierarchy( + via, + room_id, + suggested_only=suggested_only, + ) + except Exception as e: + logger.warning( + "Unable to get hierarchy of %s via federation: %s", + room_id, + e, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return None, {}, set() + + # Map the children to their room ID. + children_by_room_id = { + c["room_id"]: c + for c in children + if "room_id" in c and isinstance(c["room_id"], str) + } + + return ( + _RoomEntry(room_id, room_response, room_response.pop("children_state", ())), + children_by_room_id, + set(inaccessible_children), + ) + + async def _is_local_room_accessible( + self, room_id: str, requester: Optional[str], origin: Optional[str] = None + ) -> bool: + """ + Calculate whether the room should be shown to the requester. + + It should return true if: + + * The requester is joined or can join the room (per MSC3173). + * The origin server has any user that is joined or can join the room. + * The history visibility is set to world readable. + + Args: + room_id: The room ID to check accessibility of. + requester: + The user making the request, if it is a local request. + None if this is a federation request. + origin: + The server making the request, if it is a federation request. + None if this is a local request. + + Returns: + True if the room is accessible to the requesting user or server. + """ + state_ids = await self._store.get_current_state_ids(room_id) + + # If there's no state for the room, it isn't known. + if not state_ids: + # The user might have a pending invite for the room. + if requester and await self._store.get_invite_for_local_user_in_room( + requester, room_id + ): + return True + + logger.info("room %s is unknown, omitting from summary", room_id) + return False + + room_version = await self._store.get_room_version(room_id) + + # Include the room if it has join rules of public or knock. + join_rules_event_id = state_ids.get((EventTypes.JoinRules, "")) + if join_rules_event_id: + join_rules_event = await self._store.get_event(join_rules_event_id) + join_rule = join_rules_event.content.get("join_rule") + if join_rule == JoinRules.PUBLIC or ( + room_version.msc2403_knocking and join_rule == JoinRules.KNOCK + ): + return True + + # Include the room if it is peekable. + hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_event_id: + hist_vis_ev = await self._store.get_event(hist_vis_event_id) + hist_vis = hist_vis_ev.content.get("history_visibility") + if hist_vis == HistoryVisibility.WORLD_READABLE: + return True + + # Otherwise we need to check information specific to the user or server. + + # If we have an authenticated requesting user, check if they are a member + # of the room (or can join the room). + if requester: + member_event_id = state_ids.get((EventTypes.Member, requester), None) + + # If they're in the room they can see info on it. + if member_event_id: + member_event = await self._store.get_event(member_event_id) + if member_event.membership in (Membership.JOIN, Membership.INVITE): + return True + + # Otherwise, check if they should be allowed access via membership in a space. 
+ if await self._event_auth_handler.has_restricted_join_rules( + state_ids, room_version + ): + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join(state_ids) + ) + if await self._event_auth_handler.is_user_in_rooms( + allowed_rooms, requester + ): + return True + + # If this is a request over federation, check if the host is in the room or + # has a user who could join the room. + elif origin: + if await self._event_auth_handler.check_host_in_room( + room_id, origin + ) or await self._store.is_host_invited(room_id, origin): + return True + + # Alternately, if the host has a user in any of the spaces specified + # for access, then the host can see this room (and should do filtering + # if the requester cannot see it). + if await self._event_auth_handler.has_restricted_join_rules( + state_ids, room_version + ): + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join(state_ids) + ) + for space_id in allowed_rooms: + if await self._event_auth_handler.check_host_in_room( + space_id, origin + ): + return True + + logger.info( + "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary", + room_id, + requester or origin, + ) + return False + + async def _is_remote_room_accessible( + self, requester: str, room_id: str, room: JsonDict + ) -> bool: + """ + Calculate whether the room received over federation should be shown to the requester. + + It should return true if: + + * The requester is joined or can join the room (per MSC3173). + * The history visibility is set to world readable. + + Note that the local server is not in the requested room (which is why the + remote call was made in the first place), but the user could have access + due to an invite, etc. + + Args: + requester: The user requesting the summary. + room_id: The room ID returned over federation. + room: The summary of the room returned over federation. + + Returns: + True if the room is accessible to the requesting user. + """ + # The API doesn't return the room version so assume that a + # join rule of knock is valid. + if ( + room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) + or room.get("world_readable") is True + ): + return True + + # Check if the user is a member of any of the allowed spaces + # from the response. + allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces") + if allowed_rooms and isinstance(allowed_rooms, list): + if await self._event_auth_handler.is_user_in_rooms( + allowed_rooms, requester + ): + return True + + # Finally, check locally if we can access the room. The user might + # already be in the room (if it was a child room), or there might be a + # pending invite, etc. + return await self._is_local_room_accessible(room_id, requester) + + async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict: + """ + Generate an entry summarising a single room. + + Args: + room_id: The room ID to summarize. + for_federation: True if this is a summary requested over federation + (which includes additional fields). + + Returns: + The JSON dictionary for the room.
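_is_remote_room_accessible boils down to a predicate over the untrusted summary dict plus a membership check; a condensed sketch, with the membership lookup stubbed as a set of the requester's rooms and the final local fallback omitted:

from typing import Set

def remote_room_visible(room: dict, requesters_rooms: Set[str]) -> bool:
    # Public or knockable rooms, and world-readable rooms, are always shown.
    if room.get("join_rules") in ("public", "knock") or room.get("world_readable") is True:
        return True
    # Otherwise the requester must be in one of the allowed rooms/spaces.
    allowed = room.get("allowed_room_ids") or room.get("allowed_spaces")
    if allowed and isinstance(allowed, list):
        return any(r in requesters_rooms for r in allowed)
    return False

print(remote_room_visible({"join_rules": "invite", "allowed_room_ids": ["!s"]}, {"!s"}))  # True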
+ """ + stats = await self._store.get_room_with_stats(room_id) + + # currently this should be impossible because we call + # _is_local_room_accessible on the room before we get here, so + # there should always be an entry + assert stats is not None, "unable to retrieve stats for %s" % (room_id,) + + current_state_ids = await self._store.get_current_state_ids(room_id) + create_event = await self._store.get_event( + current_state_ids[(EventTypes.Create, "")] + ) + + entry = { + "room_id": stats["room_id"], + "name": stats["name"], + "topic": stats["topic"], + "canonical_alias": stats["canonical_alias"], + "num_joined_members": stats["joined_members"], + "avatar_url": stats["avatar"], + "join_rules": stats["join_rules"], + "world_readable": ( + stats["history_visibility"] == HistoryVisibility.WORLD_READABLE + ), + "guest_can_join": stats["guest_access"] == "can_join", + "creation_ts": create_event.origin_server_ts, + "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), + } + + # Federation requests need to provide additional information so the + # requested server is able to filter the response appropriately. + if for_federation: + room_version = await self._store.get_room_version(room_id) + if await self._event_auth_handler.has_restricted_join_rules( + current_state_ids, room_version + ): + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join( + current_state_ids + ) + ) + if allowed_rooms: + entry["allowed_room_ids"] = allowed_rooms + # TODO Remove this key once the API is stable. + entry["allowed_spaces"] = allowed_rooms + + # Filter out Nones – rather omit the field altogether + room_entry = {k: v for k, v in entry.items() if v is not None} + + return room_entry + + async def _get_child_events(self, room_id: str) -> Iterable[EventBase]: + """ + Get the child events for a given room. + + The returned results are sorted for stability. + + Args: + room_id: The room id to get the children of. + + Returns: + An iterable of sorted child events. + """ + + # look for child rooms/spaces. + current_state_ids = await self._store.get_current_state_ids(room_id) + + events = await self._store.get_events_as_list( + [ + event_id + for key, event_id in current_state_ids.items() + if key[0] == EventTypes.SpaceChild + ] + ) + + # filter out any events without a "via" (which implies it has been redacted), + # and order to ensure we return stable results. + return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) + + async def get_room_summary( + self, + requester: Optional[str], + room_id: str, + remote_room_hosts: Optional[List[str]] = None, + ) -> JsonDict: + """ + Implementation of the room summary C-S API from MSC3266 + + Args: + requester: user id of the user making this request, will be None + for unauthenticated requests + + room_id: room id to summarise. + + remote_room_hosts: a list of homeservers to try fetching data through + if we don't know it ourselves + + Returns: + summary dict to return + """ + is_in_room = await self._store.is_host_joined(room_id, self._server_name) + + if is_in_room: + room_entry = await self._summarize_local_room( + requester, + None, + room_id, + # Suggested-only doesn't matter since no children are requested. + suggested_only=False, + max_children=0, + ) + + if not room_entry: + raise NotFoundError("Room not found or is not accessible") + + room_summary = room_entry.room + + # If there was a requester, add their membership. 
+ if requester: + ( + membership, + _, + ) = await self._store.get_local_current_membership_for_user_in_room( + requester, room_id + ) + + room_summary["membership"] = membership or "leave" + else: + # TODO federation API, descoped from initial unstable implementation + # as MSC needs more maturing on that side. + raise SynapseError(400, "Federation is not currently supported.") + + return room_summary + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _RoomQueueEntry: + # The room ID of this entry. + room_id: str + # The server to query if the room is not known locally. + via: Sequence[str] + # The minimum number of hops necessary to get to this room (compared to the + # originally requested room). + depth: int = 0 + # The room summary for this room returned via federation. This will only be + # used if the room is not known locally (and is not a space). + remote_room: Optional[JsonDict] = None + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _RoomEntry: + room_id: str + # The room summary for this room. + room: JsonDict + # An iterable of the sorted, stripped children events for children of this room. + # + # This may not include all children. + children_state_events: Sequence[JsonDict] = () + + def as_json(self) -> JsonDict: + """ + Returns a JSON dictionary suitable for the room hierarchy endpoint. + + It returns the room summary including the stripped m.space.child events + as a sub-key. + """ + result = dict(self.room) + result["children_state"] = self.children_state_events + return result + + +def _has_valid_via(e: EventBase) -> bool: + via = e.content.get("via") + if not via or not isinstance(via, Sequence): + return False + for v in via: + if not isinstance(v, str): + logger.debug("Ignoring edge event %s with invalid via entry", e.event_id) + return False + return True + + +def _is_suggested_child_event(edge_event: EventBase) -> bool: + suggested = edge_event.content.get("suggested") + if isinstance(suggested, bool) and suggested: + return True + logger.debug("Ignoring not-suggested child %s", edge_event.state_key) + return False + + +# Order may only contain characters in the range of \x20 (space) to \x7E (~) inclusive. +_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]") + + +def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]: + """ + Generate a value for comparing two child events for ordering. + + The rules for ordering are supposed to be: + + 1. The 'order' key, if it is valid. + 2. The 'origin_server_ts' of the 'm.room.create' event. + 3. The 'room_id'. + + But we skip step 2 since we may not have any state from the room. + + Args: + child: The event for generating a comparison key. + + Returns: + The comparison key as a tuple of: + False if the ordering is valid. + The ordering field. + The room ID. + """ + order = child.content.get("order") + # If order is not a string or doesn't meet the requirements, ignore it. + if not isinstance(order, str): + order = None + elif len(order) > 50 or _INVALID_ORDER_CHARS_RE.search(order): + order = None + + # Items without an order come last. + return (order is None, order, child.room_id)
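The comparison key relies on tuple ordering: False sorts before True, so children with a valid order precede those without, and room_id breaks ties. A quick check:

children = [("!c", None), ("!b", "aaa"), ("!a", "bbb")]  # (room_id, order)

ranked = sorted(children, key=lambda c: (c[1] is None, c[1], c[0]))
print([room_id for room_id, _ in ranked])  # ['!b', '!a', '!c']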
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index e9f6aef06f01..dda9659c11c2 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -16,7 +16,12 @@ import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING +from io import BytesIO +from typing import TYPE_CHECKING, Optional + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IReactorTCP +from twisted.mail.smtp import ESMTPSenderFactory from synapse.logging.context import make_deferred_yieldable @@ -26,19 +31,75 @@ logger = logging.getLogger(__name__) +async def _sendmail( + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + username: Optional[bytes] = None, + password: Optional[bytes] = None, + require_auth: bool = False, + require_tls: bool = False, + tls_hostname: Optional[str] = None, +) -> None: + """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests + + Params: + reactor: reactor to use to make the outbound connection + smtphost: hostname to connect to + smtpport: port to connect to + from_addr: "From" address for email + to_addr: "To" address for email + msg_bytes: Message content + username: username to authenticate with, if auth is enabled + password: password to give when authenticating + require_auth: if auth is not offered, fail the request + require_tls: if TLS is not offered, fail the request + tls_hostname: TLS hostname to check for. None to disable TLS. + """ + msg = BytesIO(msg_bytes) + + d: "Deferred[object]" = Deferred() + + factory = ESMTPSenderFactory( + username, + password, + from_addr, + to_addr, + msg, + d, + heloFallback=True, + requireAuthentication=require_auth, + requireTransportSecurity=require_tls, + hostname=tls_hostname, + ) + + # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong + reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type] + + await make_deferred_yieldable(d) + + class SendEmailHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self._sendmail = hs.get_sendmail() self._reactor = hs.get_reactor() self._from = hs.config.email.email_notif_from self._smtp_host = hs.config.email.email_smtp_host self._smtp_port = hs.config.email.email_smtp_port - self._smtp_user = hs.config.email.email_smtp_user - self._smtp_pass = hs.config.email.email_smtp_pass + + user = hs.config.email.email_smtp_user + self._smtp_user = user.encode("utf-8") if user is not None else None + passwd = hs.config.email.email_smtp_pass + self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None self._require_transport_security = hs.config.email.require_transport_security + self._enable_tls = hs.config.email.enable_smtp_tls + + self._sendmail = _sendmail async def send_email( self, @@ -82,17 +143,16 @@ async def send_email( logger.info("Sending email to %s" % email_address) - await make_deferred_yieldable( - self._sendmail( - self._smtp_host, - raw_from, - raw_to, - multipart_msg.as_string().encode("utf8"), - reactor=self._reactor, - port=self._smtp_port, - requireAuthentication=self._smtp_user is not None, - username=self._smtp_user, - password=self._smtp_pass, - requireTransportSecurity=self._require_transport_security, - ) + await self._sendmail( + self._reactor, + self._smtp_host, + self._smtp_port, + raw_from, + raw_to, + multipart_msg.as_string().encode("utf8"), + username=self._smtp_user, + password=self._smtp_pass, + require_auth=self._smtp_user is not None, + require_tls=self._require_transport_security, + tls_hostname=self._smtp_host if self._enable_tls else None, )
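For context, this is roughly how the new _sendmail wrapper could be driven on its own; the host, addresses, and credentials below are placeholders, and in Synapse the call happens inside send_email with values taken from the homeserver config:

from twisted.internet import defer, reactor

def send_one() -> None:
    # ensureDeferred drives the coroutine under the Twisted reactor.
    d = defer.ensureDeferred(
        _sendmail(
            reactor,
            "smtp.example.org",  # placeholder host
            587,
            "noreply@example.org",
            "user@example.org",
            b"Subject: test\r\n\r\nhello",
            username=b"smtp-user",
            password=b"secret",
            require_auth=True,
            require_tls=True,
            tls_hostname="smtp.example.org",
        )
    )
    d.addBoth(lambda _: reactor.stop())

send_one()
reactor.run()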
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py deleted file mode 100644 index 5f7d4602bd8d..000000000000 --- a/synapse/handlers/space_summary.py +++ /dev/null @@ -1,667 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import itertools -import logging -import re -from collections import deque -from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple - -import attr - -from synapse.api.constants import ( - EventContentFields, - EventTypes, - HistoryVisibility, - JoinRules, - Membership, - RoomTypes, -) -from synapse.events import EventBase -from synapse.events.utils import format_event_for_client_v2 -from synapse.types import JsonDict - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - -# number of rooms to return. We'll stop once we hit this limit. -MAX_ROOMS = 50 - -# max number of events to return per room. -MAX_ROOMS_PER_SPACE = 50 - -# max number of federation servers to hit per room -MAX_SERVERS_PER_SPACE = 3 - - -class SpaceSummaryHandler: - def __init__(self, hs: "HomeServer"): - self._clock = hs.get_clock() - self._auth = hs.get_auth() - self._event_auth_handler = hs.get_event_auth_handler() - self._store = hs.get_datastore() - self._event_serializer = hs.get_event_client_serializer() - self._server_name = hs.hostname - self._federation_client = hs.get_federation_client() - - async def get_space_summary( - self, - requester: str, - room_id: str, - suggested_only: bool = False, - max_rooms_per_space: Optional[int] = None, - ) -> JsonDict: - """ - Implementation of the space summary C-S API - - Args: - requester: user id of the user making this request - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. This does not apply to the root room (ie, room_id), and - is overridden by MAX_ROOMS_PER_SPACE. - - Returns: - summary dict to return - """ - # first of all, check that the user is in the room in question (or it's - # world-readable) - await self._auth.check_user_in_room_or_world_readable(room_id, requester) - - # the queue of rooms to process - room_queue = deque((_RoomQueueEntry(room_id, ()),)) - - # rooms we have already processed - processed_rooms: Set[str] = set() - - # events we have already processed.
We don't necessarily have their event ids, - # so instead we key on (room id, state key) - processed_events: Set[Tuple[str, str]] = set() - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - while room_queue and len(rooms_result) < MAX_ROOMS: - queue_entry = room_queue.popleft() - room_id = queue_entry.room_id - if room_id in processed_rooms: - # already done this room - continue - - logger.debug("Processing room %s", room_id) - - is_in_room = await self._store.is_host_joined(room_id, self._server_name) - - # The client-specified max_rooms_per_space limit doesn't apply to the - # room_id specified in the request, so we ignore it if this is the - # first room we are processing. - max_children = max_rooms_per_space if processed_rooms else None - - if is_in_room: - room, events = await self._summarize_local_room( - requester, None, room_id, suggested_only, max_children - ) - - logger.debug( - "Query of local room %s returned events %s", - room_id, - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - - if room: - rooms_result.append(room) - else: - fed_rooms, fed_events = await self._summarize_remote_room( - queue_entry, - suggested_only, - max_children, - exclude_rooms=processed_rooms, - ) - - # The results over federation might include rooms that the we, - # as the requesting server, are allowed to see, but the requesting - # user is not permitted see. - # - # Filter the returned results to only what is accessible to the user. - room_ids = set() - events = [] - for room in fed_rooms: - fed_room_id = room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue - - # The room should only be included in the summary if: - # a. the user is in the room; - # b. the room is world readable; or - # c. the user could join the room, e.g. the join rules - # are set to public or the user is in a space that - # has been granted access to the room. - # - # Note that we know the user is not in the root room (which is - # why the remote call was made in the first place), but the user - # could be in one of the children rooms and we just didn't know - # about the link. - - # The API doesn't return the room version so assume that a - # join rule of knock is valid. - include_room = ( - room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) - or room.get("world_readable") is True - ) - - # Check if the user is a member of any of the allowed spaces - # from the response. - allowed_rooms = room.get("allowed_spaces") - if ( - not include_room - and allowed_rooms - and isinstance(allowed_rooms, list) - ): - include_room = await self._event_auth_handler.is_user_in_rooms( - allowed_rooms, requester - ) - - # Finally, if this isn't the requested room, check ourselves - # if we can access the room. - if not include_room and fed_room_id != queue_entry.room_id: - include_room = await self._is_room_accessible( - fed_room_id, requester, None - ) - - # The user can see the room, include it! - if include_room: - rooms_result.append(room) - room_ids.add(fed_room_id) - - # All rooms returned don't need visiting again (even if the user - # didn't have access to them). 
- processed_rooms.add(fed_room_id) - - for event in fed_events: - if event.get("room_id") in room_ids: - events.append(event) - - logger.debug( - "Query of %s returned rooms %s, events %s", - room_id, - [room.get("room_id") for room in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in fed_events], - ) - - # the room we queried may or may not have been returned, but don't process - # it again, anyway. - processed_rooms.add(room_id) - - # XXX: is it ok that we blindly iterate through any events returned by - # a remote server, whether or not they actually link to any rooms in our - # tree? - for ev in events: - # remote servers might return events we have already processed - # (eg, Dendrite returns inward pointers as well as outward ones), so - # we need to filter them out, to avoid returning duplicate links to the - # client. - ev_key = (ev["room_id"], ev["state_key"]) - if ev_key in processed_events: - continue - events_result.append(ev) - - # add the child to the queue. we have already validated - # that the vias are a list of server names. - room_queue.append( - _RoomQueueEntry(ev["state_key"], ev["content"]["via"]) - ) - processed_events.add(ev_key) - - # Before returning to the client, remove the allowed_spaces key for any - # rooms. - for room in rooms_result: - room.pop("allowed_spaces", None) - - return {"rooms": rooms_result, "events": events_result} - - async def federation_space_summary( - self, - origin: str, - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: Iterable[str], - ) -> JsonDict: - """ - Implementation of the space summary Federation API - - Args: - origin: The server requesting the spaces summary. - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. Unlike the C-S API, this applies to the root room (room_id). - It is clipped to MAX_ROOMS_PER_SPACE. - - exclude_rooms: a list of rooms to skip over (presumably because the - calling server has already seen them). - - Returns: - summary dict to return - """ - # the queue of rooms to process - room_queue = deque((room_id,)) - - # the set of rooms that we should not walk further. Initialise it with the - # excluded-rooms list; we will add other rooms as we process them so that - # we do not loop. - processed_rooms: Set[str] = set(exclude_rooms) - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - while room_queue and len(rooms_result) < MAX_ROOMS: - room_id = room_queue.popleft() - if room_id in processed_rooms: - # already done this room - continue - - logger.debug("Processing room %s", room_id) - - room, events = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_rooms_per_space - ) - - processed_rooms.add(room_id) - - if room: - rooms_result.append(room) - events_result.extend(events) - - # add any children to the queue - room_queue.extend(edge_event["state_key"] for edge_event in events) - - return {"rooms": rooms_result, "events": events_result} - - async def _summarize_local_room( - self, - requester: Optional[str], - origin: Optional[str], - room_id: str, - suggested_only: bool, - max_children: Optional[int], - ) -> Tuple[Optional[JsonDict], Sequence[JsonDict]]: - """ - Generate a room entry and a list of event entries for a given room. - - Args: - requester: - The user requesting the summary, if it is a local request. 
None - if this is a federation request. - origin: - The server requesting the summary, if it is a federation request. - None if this is a local request. - room_id: The room ID to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. This is capped - to a server-set limit. - - Returns: - A tuple of: - The room information, if the room should be returned to the - user. None, otherwise. - - An iterable of the sorted children events. This may be limited - to a maximum size or may include all children. - """ - if not await self._is_room_accessible(room_id, requester, origin): - return None, () - - room_entry = await self._build_room_entry(room_id) - - # If the room is not a space, return just the room information. - if room_entry.get("room_type") != RoomTypes.SPACE: - return room_entry, () - - # Otherwise, look for child rooms/spaces. - child_events = await self._get_child_events(room_id) - - if suggested_only: - # we only care about suggested children - child_events = filter(_is_suggested_child_event, child_events) - - if max_children is None or max_children > MAX_ROOMS_PER_SPACE: - max_children = MAX_ROOMS_PER_SPACE - - now = self._clock.time_msec() - events_result: List[JsonDict] = [] - for edge_event in itertools.islice(child_events, max_children): - events_result.append( - await self._event_serializer.serialize_event( - edge_event, - time_now=now, - event_format=format_event_for_client_v2, - ) - ) - - return room_entry, events_result - - async def _summarize_remote_room( - self, - room: "_RoomQueueEntry", - suggested_only: bool, - max_children: Optional[int], - exclude_rooms: Iterable[str], - ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]: - """ - Request room entries and a list of event entries for a given room by querying a remote server. - - Args: - room: The room to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. This is capped - to a server-set limit. - exclude_rooms: - Rooms IDs which do not need to be summarized. - - Returns: - A tuple of: - An iterable of rooms. - - An iterable of the sorted children events. This may be limited - to a maximum size or may include all children. - """ - room_id = room.room_id - logger.info("Requesting summary for %s via %s", room_id, room.via) - - # we need to make the exclusion list json-serialisable - exclude_rooms = list(exclude_rooms) - - via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) - try: - res = await self._federation_client.get_space_summary( - via, - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_children, - exclude_rooms=exclude_rooms, - ) - except Exception as e: - logger.warning( - "Unable to get summary of %s via federation: %s", - room_id, - e, - exc_info=logger.isEnabledFor(logging.DEBUG), - ) - return (), () - - return res.rooms, tuple( - ev.data for ev in res.events if ev.event_type == EventTypes.SpaceChild - ) - - async def _is_room_accessible( - self, room_id: str, requester: Optional[str], origin: Optional[str] - ) -> bool: - """ - Calculate whether the room should be shown in the spaces summary. - - It should be included if: - - * The requester is joined or can join the room (per MSC3173). - * The origin server has any user that is joined or can join the room. - * The history visibility is set to world readable. 
- - Args: - room_id: The room ID to summarize. - requester: - The user requesting the summary, if it is a local request. None - if this is a federation request. - origin: - The server requesting the summary, if it is a federation request. - None if this is a local request. - - Returns: - True if the room should be included in the spaces summary. - """ - state_ids = await self._store.get_current_state_ids(room_id) - - # If there's no state for the room, it isn't known. - if not state_ids: - # The user might have a pending invite for the room. - if requester and await self._store.get_invite_for_local_user_in_room( - requester, room_id - ): - return True - - logger.info("room %s is unknown, omitting from summary", room_id) - return False - - room_version = await self._store.get_room_version(room_id) - - # Include the room if it has join rules of public or knock. - join_rules_event_id = state_ids.get((EventTypes.JoinRules, "")) - if join_rules_event_id: - join_rules_event = await self._store.get_event(join_rules_event_id) - join_rule = join_rules_event.content.get("join_rule") - if join_rule == JoinRules.PUBLIC or ( - room_version.msc2403_knocking and join_rule == JoinRules.KNOCK - ): - return True - - # Include the room if it is peekable. - hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, "")) - if hist_vis_event_id: - hist_vis_ev = await self._store.get_event(hist_vis_event_id) - hist_vis = hist_vis_ev.content.get("history_visibility") - if hist_vis == HistoryVisibility.WORLD_READABLE: - return True - - # Otherwise we need to check information specific to the user or server. - - # If we have an authenticated requesting user, check if they are a member - # of the room (or can join the room). - if requester: - member_event_id = state_ids.get((EventTypes.Member, requester), None) - - # If they're in the room they can see info on it. - if member_event_id: - member_event = await self._store.get_event(member_event_id) - if member_event.membership in (Membership.JOIN, Membership.INVITE): - return True - - # Otherwise, check if they should be allowed access via membership in a space. - if await self._event_auth_handler.has_restricted_join_rules( - state_ids, room_version - ): - allowed_rooms = ( - await self._event_auth_handler.get_rooms_that_allow_join(state_ids) - ) - if await self._event_auth_handler.is_user_in_rooms( - allowed_rooms, requester - ): - return True - - # If this is a request over federation, check if the host is in the room or - # has a user who could join the room. - elif origin: - if await self._event_auth_handler.check_host_in_room( - room_id, origin - ) or await self._store.is_host_invited(room_id, origin): - return True - - # Alternately, if the host has a user in any of the spaces specified - # for access, then the host can see this room (and should do filtering - # if the requester cannot see it). 
- if await self._event_auth_handler.has_restricted_join_rules( - state_ids, room_version - ): - allowed_rooms = ( - await self._event_auth_handler.get_rooms_that_allow_join(state_ids) - ) - for space_id in allowed_rooms: - if await self._event_auth_handler.check_host_in_room( - space_id, origin - ): - return True - - logger.info( - "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary", - room_id, - requester or origin, - ) - return False - - async def _build_room_entry(self, room_id: str) -> JsonDict: - """Generate en entry suitable for the 'rooms' list in the summary response""" - stats = await self._store.get_room_with_stats(room_id) - - # currently this should be impossible because we call - # check_user_in_room_or_world_readable on the room before we get here, so - # there should always be an entry - assert stats is not None, "unable to retrieve stats for %s" % (room_id,) - - current_state_ids = await self._store.get_current_state_ids(room_id) - create_event = await self._store.get_event( - current_state_ids[(EventTypes.Create, "")] - ) - - room_version = await self._store.get_room_version(room_id) - allowed_rooms = None - if await self._event_auth_handler.has_restricted_join_rules( - current_state_ids, room_version - ): - allowed_rooms = await self._event_auth_handler.get_rooms_that_allow_join( - current_state_ids - ) - - entry = { - "room_id": stats["room_id"], - "name": stats["name"], - "topic": stats["topic"], - "canonical_alias": stats["canonical_alias"], - "num_joined_members": stats["joined_members"], - "avatar_url": stats["avatar"], - "join_rules": stats["join_rules"], - "world_readable": ( - stats["history_visibility"] == HistoryVisibility.WORLD_READABLE - ), - "guest_can_join": stats["guest_access"] == "can_join", - "creation_ts": create_event.origin_server_ts, - "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), - "allowed_spaces": allowed_rooms, - } - - # Filter out Nones – rather omit the field altogether - room_entry = {k: v for k, v in entry.items() if v is not None} - - return room_entry - - async def _get_child_events(self, room_id: str) -> Iterable[EventBase]: - """ - Get the child events for a given room. - - The returned results are sorted for stability. - - Args: - room_id: The room id to get the children of. - - Returns: - An iterable of sorted child events. - """ - - # look for child rooms/spaces. - current_state_ids = await self._store.get_current_state_ids(room_id) - - events = await self._store.get_events_as_list( - [ - event_id - for key, event_id in current_state_ids.items() - if key[0] == EventTypes.SpaceChild - ] - ) - - # filter out any events without a "via" (which implies it has been redacted), - # and order to ensure we return stable results. 
- return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) - - -@attr.s(frozen=True, slots=True) -class _RoomQueueEntry: - room_id = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) - - -def _has_valid_via(e: EventBase) -> bool: - via = e.content.get("via") - if not via or not isinstance(via, Sequence): - return False - for v in via: - if not isinstance(v, str): - logger.debug("Ignoring edge event %s with invalid via entry", e.event_id) - return False - return True - - -def _is_suggested_child_event(edge_event: EventBase) -> bool: - suggested = edge_event.content.get("suggested") - if isinstance(suggested, bool) and suggested: - return True - logger.debug("Ignorning not-suggested child %s", edge_event.state_key) - return False - - -# Order may only contain characters in the range of \x20 (space) to \x7E (~) inclusive. -_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]") - - -def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]: - """ - Generate a value for comparing two child events for ordering. - - The rules for ordering are supposed to be: - - 1. The 'order' key, if it is valid. - 2. The 'origin_server_ts' of the 'm.room.create' event. - 3. The 'room_id'. - - But we skip step 2 since we may not have any state from the room. - - Args: - child: The event for generating a comparison key. - - Returns: - The comparison key as a tuple of: - False if the ordering is valid. - The ordering field. - The room ID. - """ - order = child.content.get("order") - # If order is not a string or doesn't meet the requirements, ignore it. - if not isinstance(order, str): - order = None - elif len(order) > 50 or _INVALID_ORDER_CHARS_RE.search(order): - order = None - - # Items without an order come last. - return (order is None, order, child.room_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f30bfcc93cf2..590642f510fe 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -269,14 +269,22 @@ def __init__(self, hs: "HomeServer"): self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache( - hs.get_clock(), "sync" - ) self.state = hs.get_state_handler() self.auth = hs.get_auth() self.storage = hs.get_storage() self.state_store = self.storage.state + # TODO: flush cache entries on subsequent sync request. + # Once we get the next /sync request (ie, one with the same access token + # that sets 'since' to 'next_batch'), we know that device won't need a + # cached result any more, and we could flush the entry from the cache to save + # memory. 
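The TODO above sketches a possible follow-up (evicting a device's cached /sync response once the next request arrives); the construction that follows instead gives entries a bounded lifetime via the new `sync_response_cache_duration` option. A minimal sketch of that timed-cache behaviour, with illustrative names rather than Synapse's actual `ResponseCache` internals:

```python
import time
from typing import Any, Dict, Hashable, Optional, Tuple


class TimedResponseCache:
    """Sketch only: cache responses per request key and expire them
    after a configurable duration, trading memory for de-duplicated work."""

    def __init__(self, timeout_ms: int):
        self._timeout_s = timeout_ms / 1000
        self._entries: Dict[Hashable, Tuple[float, Any]] = {}

    def get(self, key: Hashable) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if time.monotonic() > expires_at:
            # The entry outlived the configured duration; drop it to save memory.
            del self._entries[key]
            return None
        return value

    def set(self, key: Hashable, value: Any) -> None:
        self._entries[key] = (time.monotonic() + self._timeout_s, value)
```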
+        self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache(
+            hs.get_clock(),
+            "sync",
+            timeout_ms=hs.config.caches.sync_response_cache_duration,
+        )
+
         # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
         self.lazy_loaded_members_cache: ExpiringCache[
             Tuple[str, Optional[str]], LruCache[str, str]
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0cb651a40009..a97c448595e9 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -335,7 +335,8 @@ async def _recv_edu(self, origin: str, content: JsonDict) -> None:
             )
             if not is_in_room:
                 logger.info(
-                    "Ignoring typing update from %s as we're not in the room",
+                    "Ignoring typing update for room %r from server %s as we're not in the room",
+                    room_id,
                     origin,
                 )
                 return
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 17e1c5abb13d..c577142268c5 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import base64
 import logging
+from typing import Optional
 
+import attr
 from zope.interface import implementer
 
 from twisted.internet import defer, protocol
@@ -21,7 +24,6 @@
 from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
 from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
 from twisted.web import http
-from twisted.web.http_headers import Headers
 
 logger = logging.getLogger(__name__)
 
@@ -30,6 +32,22 @@ class ProxyConnectError(ConnectError):
     pass
 
 
+@attr.s
+class ProxyCredentials:
+    username_password = attr.ib(type=bytes)
+
+    def as_proxy_authorization_value(self) -> bytes:
+        """
+        Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').
+
+        Returns:
+            The authentication string, base64-encoded and prefixed with the
+            'Basic ' authorization type, for use as a Proxy-Authorization header.
+ """ + # Encode as base64 and prepend the authorization type + return b"Basic " + base64.encodebytes(self.username_password) + + @implementer(IStreamClientEndpoint) class HTTPConnectProxyEndpoint: """An Endpoint implementation which will send a CONNECT request to an http proxy @@ -46,7 +64,7 @@ class HTTPConnectProxyEndpoint: proxy_endpoint: the endpoint to use to connect to the proxy host: hostname that we want to CONNECT to port: port that we want to connect to - headers: Extra HTTP headers to include in the CONNECT request + proxy_creds: credentials to authenticate at proxy """ def __init__( @@ -55,20 +73,20 @@ def __init__( proxy_endpoint: IStreamClientEndpoint, host: bytes, port: int, - headers: Headers, + proxy_creds: Optional[ProxyCredentials], ): self._reactor = reactor self._proxy_endpoint = proxy_endpoint self._host = host self._port = port - self._headers = headers + self._proxy_creds = proxy_creds def __repr__(self): return "" % (self._proxy_endpoint,) def connect(self, protocolFactory: ClientFactory): f = HTTPProxiedClientFactory( - self._host, self._port, protocolFactory, self._headers + self._host, self._port, protocolFactory, self._proxy_creds ) d = self._proxy_endpoint.connect(f) # once the tcp socket connects successfully, we need to wait for the @@ -87,7 +105,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): dst_host: hostname that we want to CONNECT to dst_port: port that we want to connect to wrapped_factory: The original Factory - headers: Extra HTTP headers to include in the CONNECT request + proxy_creds: credentials to authenticate at proxy """ def __init__( @@ -95,12 +113,12 @@ def __init__( dst_host: bytes, dst_port: int, wrapped_factory: ClientFactory, - headers: Headers, + proxy_creds: Optional[ProxyCredentials], ): self.dst_host = dst_host self.dst_port = dst_port self.wrapped_factory = wrapped_factory - self.headers = headers + self.proxy_creds = proxy_creds self.on_connection = defer.Deferred() def startedConnecting(self, connector): @@ -114,7 +132,7 @@ def buildProtocol(self, addr): self.dst_port, wrapped_protocol, self.on_connection, - self.headers, + self.proxy_creds, ) def clientConnectionFailed(self, connector, reason): @@ -145,7 +163,7 @@ class HTTPConnectProtocol(protocol.Protocol): connected_deferred: a Deferred which will be callbacked with wrapped_protocol when the CONNECT completes - headers: Extra HTTP headers to include in the CONNECT request + proxy_creds: credentials to authenticate at proxy """ def __init__( @@ -154,16 +172,16 @@ def __init__( port: int, wrapped_protocol: Protocol, connected_deferred: defer.Deferred, - headers: Headers, + proxy_creds: Optional[ProxyCredentials], ): self.host = host self.port = port self.wrapped_protocol = wrapped_protocol self.connected_deferred = connected_deferred - self.headers = headers + self.proxy_creds = proxy_creds self.http_setup_client = HTTPConnectSetupClient( - self.host, self.port, self.headers + self.host, self.port, self.proxy_creds ) self.http_setup_client.on_connected.addCallback(self.proxyConnected) @@ -205,30 +223,38 @@ class HTTPConnectSetupClient(http.HTTPClient): Args: host: The hostname to send in the CONNECT message port: The port to send in the CONNECT message - headers: Extra headers to send with the CONNECT message + proxy_creds: credentials to authenticate at proxy """ - def __init__(self, host: bytes, port: int, headers: Headers): + def __init__( + self, + host: bytes, + port: int, + proxy_creds: Optional[ProxyCredentials], + ): self.host = host self.port = port - 
self.headers = headers + self.proxy_creds = proxy_creds self.on_connected = defer.Deferred() def connectionMade(self): logger.debug("Connected to proxy, sending CONNECT") self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) - # Send any additional specified headers - for name, values in self.headers.getAllRawHeaders(): - for value in values: - self.sendHeader(name, value) + # Determine whether we need to set Proxy-Authorization headers + if self.proxy_creds: + # Set a Proxy-Authorization header + self.sendHeader( + b"Proxy-Authorization", + self.proxy_creds.as_proxy_authorization_value(), + ) self.endHeaders() def handleStatus(self, version: bytes, status: bytes, message: bytes): logger.debug("Got Status: %s %s %s", status, message, version) if status != b"200": - raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) + raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}") def handleEndHeaders(self): logger.debug("End Headers") diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index c16b7f10e645..1238bfd28726 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -14,6 +14,10 @@ import logging import urllib.parse from typing import Any, Generator, List, Optional +from urllib.request import ( # type: ignore[attr-defined] + getproxies_environment, + proxy_bypass_environment, +) from netaddr import AddrFormatError, IPAddress, IPSet from zope.interface import implementer @@ -30,9 +34,12 @@ from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse from synapse.crypto.context_factory import FederationPolicyForHTTPS -from synapse.http.client import BlacklistingAgentWrapper +from synapse.http import proxyagent +from synapse.http.client import BlacklistingAgentWrapper, BlacklistingReactorWrapper +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver +from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import ISynapseReactor from synapse.util import Clock @@ -57,6 +64,14 @@ class MatrixFederationAgent: user_agent: The user agent header to use for federation requests. + ip_whitelist: Allowed IP addresses. + + ip_blacklist: Disallowed IP addresses. + + proxy_reactor: twisted reactor to use for connections to the proxy server + reactor might have some blacklisting applied (i.e. for DNS queries), + but we need unblocked access to the proxy. + _srv_resolver: SrvResolver implementation to use for looking up SRV records. None to use a default implementation. @@ -71,11 +86,18 @@ def __init__( reactor: ISynapseReactor, tls_client_options_factory: Optional[FederationPolicyForHTTPS], user_agent: bytes, + ip_whitelist: IPSet, ip_blacklist: IPSet, _srv_resolver: Optional[SrvResolver] = None, _well_known_resolver: Optional[WellKnownResolver] = None, ): - self._reactor = reactor + # proxy_reactor is not blacklisted + proxy_reactor = reactor + + # We need to use a DNS resolver which filters out blacklisted IP + # addresses, to prevent DNS rebinding. 
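The comment above, together with the `BlacklistingReactorWrapper` applied just below, is a DNS-rebinding defence: names are resolved through a filter so that an attacker-controlled hostname cannot resolve into a private range. A rough sketch of the idea, using hypothetical names (this is not Synapse's wrapper API):

```python
import socket

from netaddr import IPAddress, IPSet

# Illustrative ranges only; a real deployment takes these from config.
BLACKLIST = IPSet(["127.0.0.0/8", "10.0.0.0/8", "169.254.0.0/16"])


def resolve_checked(hostname: str) -> str:
    """Resolve a name, refusing any answer inside a blacklisted range."""
    ip = socket.gethostbyname(hostname)
    if IPAddress(ip) in BLACKLIST:
        raise ValueError(f"refusing blacklisted address {ip} for {hostname}")
    return ip
```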
+ reactor = BlacklistingReactorWrapper(reactor, ip_whitelist, ip_blacklist) + self._clock = Clock(reactor) self._pool = HTTPConnectionPool(reactor) self._pool.retryAutomatically = False @@ -83,24 +105,27 @@ def __init__( self._pool.cachedConnectionTimeout = 2 * 60 self._agent = Agent.usingEndpointFactory( - self._reactor, + reactor, MatrixHostnameEndpointFactory( - reactor, tls_client_options_factory, _srv_resolver + reactor, + proxy_reactor, + tls_client_options_factory, + _srv_resolver, ), pool=self._pool, ) self.user_agent = user_agent if _well_known_resolver is None: - # Note that the name resolver has already been wrapped in a - # IPBlacklistingResolver by MatrixFederationHttpClient. _well_known_resolver = WellKnownResolver( - self._reactor, + reactor, agent=BlacklistingAgentWrapper( - Agent( - self._reactor, + ProxyAgent( + reactor, + proxy_reactor, pool=self._pool, contextFactory=tls_client_options_factory, + use_proxy=True, ), ip_blacklist=ip_blacklist, ), @@ -200,10 +225,12 @@ class MatrixHostnameEndpointFactory: def __init__( self, reactor: IReactorCore, + proxy_reactor: IReactorCore, tls_client_options_factory: Optional[FederationPolicyForHTTPS], srv_resolver: Optional[SrvResolver], ): self._reactor = reactor + self._proxy_reactor = proxy_reactor self._tls_client_options_factory = tls_client_options_factory if srv_resolver is None: @@ -211,9 +238,10 @@ def __init__( self._srv_resolver = srv_resolver - def endpointForURI(self, parsed_uri): + def endpointForURI(self, parsed_uri: URI): return MatrixHostnameEndpoint( self._reactor, + self._proxy_reactor, self._tls_client_options_factory, self._srv_resolver, parsed_uri, @@ -227,23 +255,45 @@ class MatrixHostnameEndpoint: Args: reactor: twisted reactor to use for underlying requests + proxy_reactor: twisted reactor to use for connections to the proxy server. + 'reactor' might have some blacklisting applied (i.e. for DNS queries), + but we need unblocked access to the proxy. tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. srv_resolver: The SRV resolver to use parsed_uri: The parsed URI that we're wanting to connect to. + + Raises: + ValueError if the environment variables contain an invalid proxy specification. + RuntimeError if no tls_options_factory is given for a https connection """ def __init__( self, reactor: IReactorCore, + proxy_reactor: IReactorCore, tls_client_options_factory: Optional[FederationPolicyForHTTPS], srv_resolver: SrvResolver, parsed_uri: URI, ): self._reactor = reactor - self._parsed_uri = parsed_uri + # http_proxy is not needed because federation is always over TLS + proxies = getproxies_environment() + https_proxy = proxies["https"].encode() if "https" in proxies else None + self.no_proxy = proxies["no"] if "no" in proxies else None + + # endpoint and credentials to use to connect to the outbound https proxy, if any. 
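The proxy settings consumed in this constructor come straight from the standard library; the assignment just below unpacks the resulting endpoint and credentials. A hedged sketch of how those stdlib helpers behave (hostnames and environment values are illustrative assumptions):

```python
from urllib.request import getproxies_environment, proxy_bypass_environment

# With e.g. HTTPS_PROXY=http://user:pass@proxy:8888 and NO_PROXY=localhost set:
proxies = getproxies_environment()
# -> {"https": "http://user:pass@proxy:8888", "no": "localhost"}

https_proxy = proxies["https"].encode() if "https" in proxies else None
no_proxy = proxies["no"] if "no" in proxies else None

if no_proxy is not None:
    # True when the destination host matches an entry in no_proxy.
    skip_proxy = proxy_bypass_environment("matrix.internal", proxies={"no": no_proxy})
```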
+ ( + self._https_proxy_endpoint, + self._https_proxy_creds, + ) = proxyagent.http_proxy_endpoint( + https_proxy, + proxy_reactor, + tls_client_options_factory, + ) + # set up the TLS connection params # # XXX disabling TLS is really only supported here for the benefit of the @@ -273,9 +323,33 @@ async def _do_connect(self, protocol_factory: IProtocolFactory) -> None: host = server.host port = server.port + should_skip_proxy = False + if self.no_proxy is not None: + should_skip_proxy = proxy_bypass_environment( + host.decode(), + proxies={"no": self.no_proxy}, + ) + + endpoint: IStreamClientEndpoint try: - logger.debug("Connecting to %s:%i", host.decode("ascii"), port) - endpoint = HostnameEndpoint(self._reactor, host, port) + if self._https_proxy_endpoint and not should_skip_proxy: + logger.debug( + "Connecting to %s:%i via %s", + host.decode("ascii"), + port, + self._https_proxy_endpoint, + ) + endpoint = HTTPConnectProxyEndpoint( + self._reactor, + self._https_proxy_endpoint, + host, + port, + proxy_creds=self._https_proxy_creds, + ) + else: + logger.debug("Connecting to %s:%i", host.decode("ascii"), port) + # not using a proxy + endpoint = HostnameEndpoint(self._reactor, host, port) if self._tls_options: endpoint = wrapClientTLS(self._tls_options, endpoint) result = await make_deferred_yieldable( diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 2efa15bf0470..2e9898997c4c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -59,7 +59,6 @@ from synapse.http import QuieterFileBodyProducer from synapse.http.client import ( BlacklistingAgentWrapper, - BlacklistingReactorWrapper, BodyExceededMaxSize, ByteWriteable, encode_query_args, @@ -69,7 +68,7 @@ from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags -from synapse.types import ISynapseReactor, JsonDict +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -325,13 +324,7 @@ def __init__(self, hs, tls_client_options_factory): self.signing_key = hs.signing_key self.server_name = hs.hostname - # We need to use a DNS resolver which filters out blacklisted IP - # addresses, to prevent DNS rebinding. - self.reactor: ISynapseReactor = BlacklistingReactorWrapper( - hs.get_reactor(), - hs.config.federation_ip_range_whitelist, - hs.config.federation_ip_range_blacklist, - ) + self.reactor = hs.get_reactor() user_agent = hs.version_string if hs.config.user_agent_suffix: @@ -342,6 +335,7 @@ def __init__(self, hs, tls_client_options_factory): self.reactor, tls_client_options_factory, user_agent, + hs.config.federation_ip_range_whitelist, hs.config.federation_ip_range_blacklist, ) diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 19e987f11877..a3f31452d0cc 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
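The hunk below removes `ProxyCredentials` from `proxyagent.py`; it now lives in `connectproxyclient.py` (see above). For reference, the header value it produces is ordinary basic proxy authentication; a minimal sketch with made-up credentials:

```python
import base64


def proxy_authorization_value(username_password: bytes) -> bytes:
    # Base64-encode "user:password" and prepend the authorization type,
    # as ProxyCredentials.as_proxy_authorization_value does.
    return b"Basic " + base64.encodebytes(username_password)


print(proxy_authorization_value(b"user:password"))
# b'Basic dXNlcjpwYXNzd29yZA==\n'
```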
-import base64 import logging import re from typing import Any, Dict, Optional, Tuple @@ -21,7 +20,6 @@ proxy_bypass_environment, ) -import attr from zope.interface import implementer from twisted.internet import defer @@ -38,7 +36,7 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS -from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -46,22 +44,6 @@ _VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") -@attr.s -class ProxyCredentials: - username_password = attr.ib(type=bytes) - - def as_proxy_authorization_value(self) -> bytes: - """ - Return the value for a Proxy-Authorization header (i.e. 'Basic abdef=='). - - Returns: - A transformation of the authentication string the encoded value for - a Proxy-Authorization header. - """ - # Encode as base64 and prepend the authorization type - return b"Basic " + base64.encodebytes(self.username_password) - - @implementer(IAgent) class ProxyAgent(_AgentBase): """An Agent implementation which will use an HTTP proxy if one was requested @@ -95,6 +77,7 @@ class ProxyAgent(_AgentBase): Raises: ValueError if use_proxy is set and the environment variables contain an invalid proxy specification. + RuntimeError if no tls_options_factory is given for a https connection """ def __init__( @@ -131,11 +114,11 @@ def __init__( https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None - self.http_proxy_endpoint, self.http_proxy_creds = _http_proxy_endpoint( + self.http_proxy_endpoint, self.http_proxy_creds = http_proxy_endpoint( http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs ) - self.https_proxy_endpoint, self.https_proxy_creds = _http_proxy_endpoint( + self.https_proxy_endpoint, self.https_proxy_creds = http_proxy_endpoint( https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs ) @@ -224,22 +207,12 @@ def request( and self.https_proxy_endpoint and not should_skip_proxy ): - connect_headers = Headers() - - # Determine whether we need to set Proxy-Authorization headers - if self.https_proxy_creds: - # Set a Proxy-Authorization header - connect_headers.addRawHeader( - b"Proxy-Authorization", - self.https_proxy_creds.as_proxy_authorization_value(), - ) - endpoint = HTTPConnectProxyEndpoint( self.proxy_reactor, self.https_proxy_endpoint, parsed_uri.host, parsed_uri.port, - headers=connect_headers, + self.https_proxy_creds, ) else: # not using a proxy @@ -268,10 +241,10 @@ def request( ) -def _http_proxy_endpoint( +def http_proxy_endpoint( proxy: Optional[bytes], reactor: IReactorCore, - tls_options_factory: IPolicyForHTTPS, + tls_options_factory: Optional[IPolicyForHTTPS], **kwargs, ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: """Parses an http proxy setting and returns an endpoint for the proxy @@ -294,6 +267,7 @@ def _http_proxy_endpoint( Raise: ValueError if proxy has no hostname or unsupported scheme. 
+        RuntimeError if no tls_options_factory is given for a https connection
     """
     if proxy is None:
         return None, None
@@ -305,8 +279,13 @@ def _http_proxy_endpoint(
     proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs)
 
     if scheme == b"https":
-        tls_options = tls_options_factory.creatorForNetloc(host, port)
-        proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint)
+        if tls_options_factory:
+            tls_options = tls_options_factory.creatorForNetloc(host, port)
+            proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint)
+        else:
+            raise RuntimeError(
+                f"No TLS options for a https connection via proxy {proxy!s}"
+            )
 
     return proxy_endpoint, credentials
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 732a1e6aeb88..a12fa30bfdd2 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,16 +14,28 @@
 """ This module contains base REST classes for constructing REST servlets. """
 
 import logging
-from typing import Iterable, List, Mapping, Optional, Sequence, overload
+from typing import (
+    TYPE_CHECKING,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    overload,
+)
 
 from typing_extensions import Literal
 
 from twisted.web.server import Request
 
 from synapse.api.errors import Codes, SynapseError
-from synapse.types import JsonDict
+from synapse.types import JsonDict, RoomAlias, RoomID
 from synapse.util import json_decoder
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
@@ -663,3 +675,45 @@ def register(self, http_server):
         else:
             raise NotImplementedError("RestServlet must register something.")
+
+
+class ResolveRoomIdMixin:
+    def __init__(self, hs: "HomeServer"):
+        self.room_member_handler = hs.get_room_member_handler()
+
+    async def resolve_room_id(
+        self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None
+    ) -> Tuple[str, Optional[List[str]]]:
+        """
+        Resolve a room identifier to a room ID, if necessary.
+
+        This also performs checks to ensure the room ID is of the proper form.
+
+        Args:
+            room_identifier: The room ID or alias.
+            remote_room_hosts: The potential remote room hosts to use.
+
+        Returns:
+            The resolved room ID and the remote room hosts to use, if any.
+
+        Raises:
+            SynapseError if the room ID is of the wrong form.
+ """ + if RoomID.is_valid(room_identifier): + resolved_room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + ( + room_id, + remote_room_hosts, + ) = await self.room_member_handler.lookup_room_alias(room_alias) + resolved_room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + if not resolved_room_id: + raise SynapseError( + 400, "Unknown room ID or room alias %s" % room_identifier + ) + return resolved_room_id, remote_room_hosts diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 473812b8e295..2d2ed229e208 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -45,7 +45,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester from synapse.util import Clock from synapse.util.caches.descriptors import cached @@ -91,6 +91,7 @@ def __init__(self, hs: "HomeServer", auth_handler): self._state = hs.get_state_handler() self._clock: Clock = hs.get_clock() self._send_email_handler = hs.get_send_email_handler() + self.custom_template_dir = hs.config.server.custom_template_directory try: app_name = self._hs.config.email_app_name @@ -174,6 +175,16 @@ def email_app_name(self) -> str: """The application name configured in the homeserver's configuration.""" return self._hs.config.email.email_app_name + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get user info by user_id + + Args: + user_id: Fully qualified user id. + Returns: + UserInfo object if a user was found, otherwise None + """ + return await self._store.get_userinfo_by_id(user_id) + async def get_user_by_req( self, req: SynapseRequest, @@ -593,10 +604,15 @@ def looping_background_call( msec: float, *args, desc: Optional[str] = None, + run_on_all_instances: bool = False, **kwargs, ): """Wraps a function as a background process and calls it repeatedly. + NOTE: Will only run on the instance that is configured to run + background processes (which is the main process by default), unless + `run_on_all_workers` is set. + Waits `msec` initially before calling `f` for the first time. Args: @@ -607,12 +623,14 @@ def looping_background_call( msec: How long to wait between calls in milliseconds. *args: Positional arguments to pass to function. desc: The background task's description. Default to the function's name. + run_on_all_instances: Whether to run this on all instances, rather + than just the instance configured to run background tasks. **kwargs: Key arguments to pass to function. """ if desc is None: desc = f.__name__ - if self._hs.config.run_background_tasks: + if self._hs.config.run_background_tasks or run_on_all_instances: self._clock.looping_call( run_as_background_process, msec, @@ -667,7 +685,10 @@ def read_templates( A list containing the loaded templates, with the orders matching the one of the filenames parameter. 
""" - return self._hs.config.read_templates(filenames, custom_template_directory) + return self._hs.config.read_templates( + filenames, + (td for td in (self.custom_template_dir, custom_template_directory) if td), + ) class PublicRoomListManager: diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py deleted file mode 100644 index 8cc6de3f4698..000000000000 --- a/synapse/replication/slave/storage/room.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.replication.tcp.streams import PublicRoomsStream -from synapse.storage.database import DatabasePool -from synapse.storage.databases.main.room import RoomWorkerStore - -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker - - -class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - self._public_room_id_gen = SlavedIdTracker( - db_conn, "public_room_list_stream", "stream_id" - ) - - def get_current_public_room_stream_id(self): - return self._public_room_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == PublicRoomsStream.NAME: - self._public_room_id_gen.advance(instance_name, token) - - return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 4c0023c68aee..f41eabd85e58 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -32,7 +32,6 @@ GroupServerStream, PresenceFederationStream, PresenceStream, - PublicRoomsStream, PushersStream, PushRulesStream, ReceiptsStream, @@ -57,7 +56,6 @@ PushRulesStream, PushersStream, CachesStream, - PublicRoomsStream, DeviceListsStream, ToDeviceStream, FederationStream, @@ -79,7 +77,6 @@ "PushRulesStream", "PushersStream", "CachesStream", - "PublicRoomsStream", "DeviceListsStream", "ToDeviceStream", "TagAccountDataStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 3716c41bea7b..9b905aba9dbb 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -447,31 +447,6 @@ def __init__(self, hs): ) -class PublicRoomsStream(Stream): - """The public rooms list changed""" - - PublicRoomsStreamRow = namedtuple( - "PublicRoomsStreamRow", - ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional - ), - ) - - NAME = "public_rooms" - ROW_TYPE = PublicRoomsStreamRow - - def __init__(self, hs): - store = hs.get_datastore() - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_current_public_room_stream_id), - store.get_all_new_public_rooms, - ) - - class DeviceListsStream(Stream): """Either a user has updated their devices or a 
remote server needs to be told about a device update. diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index d29f2fea5ed3..3adc57612435 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -14,39 +14,36 @@ # limitations under the License. from synapse.http.server import JsonResource from synapse.rest import admin -from synapse.rest.client import versions -from synapse.rest.client.v1 import ( - directory, - events, - initial_sync, - login as v1_login, - logout, - presence, - profile, - push_rule, - pusher, - room, - voip, -) -from synapse.rest.client.v2_alpha import ( +from synapse.rest.client import ( account, account_data, account_validity, auth, capabilities, devices, + directory, + events, filter, groups, + initial_sync, keys, knock, + login as v1_login, + logout, notifications, openid, password_policy, + presence, + profile, + push_rule, + pusher, read_marker, receipts, register, relations, report_event, + room, + room_batch, room_keys, room_upgrade_rest_servlet, sendtodevice, @@ -56,6 +53,8 @@ thirdparty, tokenrefresh, user_directory, + versions, + voip, ) @@ -84,7 +83,6 @@ def register_servlets(client_resource, hs): # Partially deprecated in r0 events.register_servlets(hs, client_resource) - # "v1" + "r0" room.register_servlets(hs, client_resource) v1_login.register_servlets(hs, client_resource) profile.register_servlets(hs, client_resource) @@ -94,8 +92,6 @@ def register_servlets(client_resource, hs): pusher.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource) logout.register_servlets(hs, client_resource) - - # "v2" sync.register_servlets(hs, client_resource) filter.register_servlets(hs, client_resource) account.register_servlets(hs, client_resource) @@ -117,6 +113,7 @@ def register_servlets(client_resource, hs): user_directory.register_servlets(hs, client_resource) groups.register_servlets(hs, client_resource) room_upgrade_rest_servlet.register_servlets(hs, client_resource) + room_batch.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index abf749b001ab..d5862a4da436 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -51,6 +51,7 @@ ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet +from synapse.rest.admin.username_available import UsernameAvailableRestServlet from synapse.rest.admin.users import ( AccountValidityRenewServlet, DeactivateAccountRestServlet, @@ -60,7 +61,6 @@ SearchUsersRestServlet, ShadowBanRestServlet, UserAdminServlet, - UserMediaRestServlet, UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, @@ -224,7 +224,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) - UserMediaRestServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server) UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) @@ -241,6 +240,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ForwardExtremitiesRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) 
 RateLimitRestServlet(hs).register(http_server)
+    UsernameAvailableRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 0a19a333d7f7..8ce443049e23 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -18,14 +18,15 @@
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
+from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import (
     admin_patterns,
     assert_requester_is_admin,
     assert_user_is_admin,
 )
-from synapse.types import JsonDict
+from synapse.storage.databases.main.media_repository import MediaSortOrder
+from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -259,7 +260,9 @@ async def on_DELETE(
 
         logging.info("Deleting local media by ID: %s", media_id)
 
-        deleted_media, total = await self.media_repository.delete_local_media(media_id)
+        deleted_media, total = await self.media_repository.delete_local_media_ids(
+            [media_id]
+        )
         return 200, {"deleted_media": deleted_media, "total": total}
 
@@ -312,6 +315,165 @@ async def on_POST(
         return 200, {"deleted_media": deleted_media, "total": total}
 
 
+class UserMediaRestServlet(RestServlet):
+    """
+    Gets information about all uploaded local media for a specific `user_id`.
+    A DELETE request deletes all of this media.
+
+    Example:
+        http://localhost:8008/_synapse/admin/v1/users/@user:server/media
+
+    Args:
+        The parameters `from` and `limit` control pagination.
+        By default, a `limit` of 100 is used.
+    Returns:
+        A list of media and an integer representing the total number of
+        media that exist for this user
+    """
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.is_mine = hs.is_mine
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.media_repository = hs.get_media_repository()
+
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.is_mine(UserID.from_string(user_id)):
+            raise SynapseError(400, "Can only look up local users")
+
+        user = await self.store.get_user_by_id(user_id)
+        if user is None:
+            raise NotFoundError("Unknown user")
+
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
+
+        if start < 0:
+            raise SynapseError(
+                400,
+                "Query parameter from must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                400,
+                "Query parameter limit must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        # If neither `order_by` nor `dir` is set, set the default order
+        # to newest media is on top for backward compatibility.
+ if b"order_by" not in request.args and b"dir" not in request.args: + order_by = MediaSortOrder.CREATED_TS.value + direction = "b" + else: + order_by = parse_string( + request, + "order_by", + default=MediaSortOrder.CREATED_TS.value, + allowed_values=( + MediaSortOrder.MEDIA_ID.value, + MediaSortOrder.UPLOAD_NAME.value, + MediaSortOrder.CREATED_TS.value, + MediaSortOrder.LAST_ACCESS_TS.value, + MediaSortOrder.MEDIA_LENGTH.value, + MediaSortOrder.MEDIA_TYPE.value, + MediaSortOrder.QUARANTINED_BY.value, + MediaSortOrder.SAFE_FROM_QUARANTINE.value, + ), + ) + direction = parse_string( + request, "dir", default="f", allowed_values=("f", "b") + ) + + media, total = await self.store.get_local_media_by_user_paginate( + start, limit, user_id, order_by, direction + ) + + ret = {"media": media, "total": total} + if (start + limit) < total: + ret["next_token"] = start + len(media) + + return 200, ret + + async def on_DELETE( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + # This will always be set by the time Twisted calls us. + assert request.args is not None + + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine(UserID.from_string(user_id)): + raise SynapseError(400, "Can only look up local users") + + user = await self.store.get_user_by_id(user_id) + if user is None: + raise NotFoundError("Unknown user") + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + # If neither `order_by` nor `dir` is set, set the default order + # to newest media is on top for backward compatibility. + if b"order_by" not in request.args and b"dir" not in request.args: + order_by = MediaSortOrder.CREATED_TS.value + direction = "b" + else: + order_by = parse_string( + request, + "order_by", + default=MediaSortOrder.CREATED_TS.value, + allowed_values=( + MediaSortOrder.MEDIA_ID.value, + MediaSortOrder.UPLOAD_NAME.value, + MediaSortOrder.CREATED_TS.value, + MediaSortOrder.LAST_ACCESS_TS.value, + MediaSortOrder.MEDIA_LENGTH.value, + MediaSortOrder.MEDIA_TYPE.value, + MediaSortOrder.QUARANTINED_BY.value, + MediaSortOrder.SAFE_FROM_QUARANTINE.value, + ), + ) + direction = parse_string( + request, "dir", default="f", allowed_values=("f", "b") + ) + + media, _ = await self.store.get_local_media_by_user_paginate( + start, limit, user_id, order_by, direction + ) + + deleted_media, total = await self.media_repository.delete_local_media_ids( + ([row["media_id"] for row in media]) + ) + + return 200, {"deleted_media": deleted_media, "total": total} + + def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None: """ Media repo specific APIs. 
@@ -326,3 +488,4 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) ListMediaInRoom(hs).register(http_server) DeleteMediaByID(hs).register(http_server) DeleteMediaByDateSize(hs).register(http_server) + UserMediaRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 40ee33646cb2..975c28b2258e 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -20,6 +20,7 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.http.servlet import ( + ResolveRoomIdMixin, RestServlet, assert_params_in_dict, parse_integer, @@ -33,7 +34,7 @@ assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester +from synapse.types import JsonDict, UserID, create_requester from synapse.util import json_decoder if TYPE_CHECKING: @@ -45,48 +46,6 @@ logger = logging.getLogger(__name__) -class ResolveRoomIdMixin: - def __init__(self, hs: "HomeServer"): - self.room_member_handler = hs.get_room_member_handler() - - async def resolve_room_id( - self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None - ) -> Tuple[str, Optional[List[str]]]: - """ - Resolve a room identifier to a room ID, if necessary. - - This also performanes checks to ensure the room ID is of the proper form. - - Args: - room_identifier: The room ID or alias. - remote_room_hosts: The potential remote room hosts to use. - - Returns: - The resolved room ID. - - Raises: - SynapseError if the room ID is of the wrong form. - """ - if RoomID.is_valid(room_identifier): - resolved_room_id = room_identifier - elif RoomAlias.is_valid(room_identifier): - room_alias = RoomAlias.from_string(room_identifier) - ( - room_id, - remote_room_hosts, - ) = await self.room_member_handler.lookup_room_alias(room_alias) - resolved_room_id = room_id.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) - if not resolved_room_id: - raise SynapseError( - 400, "Unknown room ID or room alias %s" % room_identifier - ) - return resolved_room_id, remote_room_hosts - - class ShutdownRoomRestServlet(RestServlet): """Shuts down a room by removing all local users from the room and blocking all future invites and joins to the room. Any local aliases will be repointed diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py new file mode 100644 index 000000000000..2bf1472967dd --- /dev/null +++ b/synapse/rest/admin/username_available.py @@ -0,0 +1,51 @@ +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
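The module added here exposes the username-availability check as a new admin endpoint, usable regardless of whether registration is enabled. A hedged sketch of calling it, where the base URL and token are placeholder assumptions:

```python
import json
import urllib.request

BASE_URL = "http://localhost:8008"   # assumed homeserver address
TOKEN = "ADMIN_ACCESS_TOKEN"         # assumed admin access token

req = urllib.request.Request(
    f"{BASE_URL}/_synapse/admin/v1/username_available?username=foo",
    headers={"Authorization": f"Bearer {TOKEN}"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))  # e.g. {"available": True}
```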
+import logging +from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple + +from synapse.http.servlet import RestServlet, parse_string +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class UsernameAvailableRestServlet(RestServlet): + """An admin API to check if a given username is available, regardless of whether registration is enabled. + + Example: + GET /_synapse/admin/v1/username_available?username=foo + 200 OK + { + "available": true + } + """ + + PATTERNS = admin_patterns("/username_available") + + def __init__(self, hs: "HomeServer"): + self.auth = hs.get_auth() + self.registration_handler = hs.get_registration_handler() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + username = parse_string(request, "username", required=True) + await self.registration_handler.check_username(username) + return HTTPStatus.OK, {"available": True} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index eef76ab18a94..3c8a0c6883dc 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -34,8 +34,7 @@ assert_requester_is_admin, assert_user_is_admin, ) -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.storage.databases.main.media_repository import MediaSortOrder +from synapse.rest.client._base import client_patterns from synapse.storage.databases.main.stats import UserSortOrder from synapse.types import JsonDict, UserID @@ -172,7 +171,7 @@ async def on_GET( target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(400, "Can only look up local users") ret = await self.admin_handler.get_user(target_user) @@ -196,20 +195,57 @@ async def on_PUT( user = await self.admin_handler.get_user(target_user) user_id = target_user.to_string() + # check for required parameters for each threepid + threepids = body.get("threepids") + if threepids is not None: + for threepid in threepids: + assert_params_in_dict(threepid, ["medium", "address"]) + + # check for required parameters for each external_id + external_ids = body.get("external_ids") + if external_ids is not None: + for external_id in external_ids: + assert_params_in_dict(external_id, ["auth_provider", "external_id"]) + + user_type = body.get("user_type", None) + if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: + raise SynapseError(400, "Invalid user type") + + set_admin_to = body.get("admin", False) + if not isinstance(set_admin_to, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'admin' must be a boolean, if given", + Codes.BAD_JSON, + ) + + password = body.get("password", None) + if password is not None: + if not isinstance(password, str) or len(password) > 512: + raise SynapseError(400, "Invalid password") + + deactivate = body.get("deactivated", False) + if not isinstance(deactivate, bool): + raise SynapseError(400, "'deactivated' parameter is not of type boolean") + + # convert into List[Tuple[str, str]] + if external_ids is not None: + new_external_ids = [] + for external_id in external_ids: + new_external_ids.append( + (external_id["auth_provider"], external_id["external_id"]) + ) + if user: # modify user if "displayname" in body: await 
self.profile_handler.set_displayname( target_user, requester, body["displayname"], True ) - if "threepids" in body: - # check for required parameters for each threepid - for threepid in body["threepids"]: - assert_params_in_dict(threepid, ["medium", "address"]) - + if threepids is not None: # remove old threepids from user - threepids = await self.store.user_get_threepids(user_id) - for threepid in threepids: + old_threepids = await self.store.user_get_threepids(user_id) + for threepid in old_threepids: try: await self.auth_handler.delete_threepid( user_id, threepid["medium"], threepid["address"], None @@ -220,18 +256,39 @@ async def on_PUT( # add new threepids to user current_time = self.hs.get_clock().time_msec() - for threepid in body["threepids"]: + for threepid in threepids: await self.auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], current_time ) - if "avatar_url" in body and type(body["avatar_url"]) == str: + if external_ids is not None: + # get changed external_ids (added and removed) + cur_external_ids = await self.store.get_external_ids_by_user(user_id) + add_external_ids = set(new_external_ids) - set(cur_external_ids) + del_external_ids = set(cur_external_ids) - set(new_external_ids) + + # remove old external_ids + for auth_provider, external_id in del_external_ids: + await self.store.remove_user_external_id( + auth_provider, + external_id, + user_id, + ) + + # add new external_ids + for auth_provider, external_id in add_external_ids: + await self.store.record_user_external_id( + auth_provider, + external_id, + user_id, + ) + + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True ) if "admin" in body: - set_admin_to = bool(body["admin"]) if set_admin_to != user["admin"]: auth_user = requester.user if target_user == auth_user and not set_admin_to: @@ -239,29 +296,18 @@ async def on_PUT( await self.store.set_server_admin(target_user, set_admin_to) - if "password" in body: - if not isinstance(body["password"], str) or len(body["password"]) > 512: - raise SynapseError(400, "Invalid password") - else: - new_password = body["password"] - logout_devices = True - - new_password_hash = await self.auth_handler.hash(new_password) - - await self.set_password_handler.set_password( - target_user.to_string(), - new_password_hash, - logout_devices, - requester, - ) + if password is not None: + logout_devices = True + new_password_hash = await self.auth_handler.hash(password) + + await self.set_password_handler.set_password( + target_user.to_string(), + new_password_hash, + logout_devices, + requester, + ) if "deactivated" in body: - deactivate = body["deactivated"] - if not isinstance(deactivate, bool): - raise SynapseError( - 400, "'deactivated' parameter is not of type boolean" - ) - if deactivate and not user["deactivated"]: await self.deactivate_account_handler.deactivate_account( target_user.to_string(), False, requester, by_admin=True @@ -285,36 +331,24 @@ async def on_PUT( return 200, user else: # create user - password = body.get("password") + displayname = body.get("displayname", None) + password_hash = None if password is not None: - if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") password_hash = await self.auth_handler.hash(password) - admin = body.get("admin", None) - user_type = body.get("user_type", None) - displayname = body.get("displayname", None) - - if user_type is not None and user_type not 
in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") - user_id = await self.registration_handler.register_user( localpart=target_user.localpart, password_hash=password_hash, - admin=bool(admin), + admin=set_admin_to, default_display_name=displayname, user_type=user_type, by_admin=True, ) - if "threepids" in body: - # check for required parameters for each threepid - for threepid in body["threepids"]: - assert_params_in_dict(threepid, ["medium", "address"]) - + if threepids is not None: current_time = self.hs.get_clock().time_msec() - for threepid in body["threepids"]: + for threepid in threepids: await self.auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], current_time ) @@ -334,6 +368,14 @@ async def on_PUT( data={}, ) + if external_ids is not None: + for auth_provider, external_id in new_external_ids: + await self.store.record_user_external_id( + auth_provider, + external_id, + user_id, + ) + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True @@ -461,7 +503,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication - from synapse.rest.client.v2_alpha.register import RegisterRestServlet + from synapse.rest.client.register import RegisterRestServlet register = RegisterRestServlet(self.hs) @@ -796,7 +838,7 @@ async def on_GET( await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(400, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -808,97 +850,6 @@ async def on_GET( return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} -class UserMediaRestServlet(RestServlet): - """ - Gets information about all uploaded local media for a specific `user_id`. - - Example: - http://localhost:8008/_synapse/admin/v1/users/ - @user:server/media - - Args: - The parameters `from` and `limit` are required for pagination. - By default, a `limit` of 100 is used. - Returns: - A list of media and an integer representing the total number of - media that exist given for this user - """ - - PATTERNS = admin_patterns("/users/(?P[^/]+)/media$") - - def __init__(self, hs: "HomeServer"): - self.is_mine = hs.is_mine - self.auth = hs.get_auth() - self.store = hs.get_datastore() - - async def on_GET( - self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: - # This will always be set by the time Twisted calls us. 
- assert request.args is not None - - await assert_requester_is_admin(self.auth, request) - - if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only lookup local users") - - user = await self.store.get_user_by_id(user_id) - if user is None: - raise NotFoundError("Unknown user") - - start = parse_integer(request, "from", default=0) - limit = parse_integer(request, "limit", default=100) - - if start < 0: - raise SynapseError( - 400, - "Query parameter from must be a string representing a positive integer.", - errcode=Codes.INVALID_PARAM, - ) - - if limit < 0: - raise SynapseError( - 400, - "Query parameter limit must be a string representing a positive integer.", - errcode=Codes.INVALID_PARAM, - ) - - # If neither `order_by` nor `dir` is set, set the default order - # to newest media is on top for backward compatibility. - if b"order_by" not in request.args and b"dir" not in request.args: - order_by = MediaSortOrder.CREATED_TS.value - direction = "b" - else: - order_by = parse_string( - request, - "order_by", - default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), - ) - direction = parse_string( - request, "dir", default="f", allowed_values=("f", "b") - ) - - media, total = await self.store.get_local_media_by_user_paginate( - start, limit, user_id, order_by, direction - ) - - ret = {"media": media, "total": total} - if (start + limit) < total: - ret["next_token"] = start + len(media) - - return 200, ret - - class UserTokenRestServlet(RestServlet): """An admin API for logging in as a user. @@ -1017,7 +968,7 @@ async def on_GET( await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(400, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") diff --git a/synapse/rest/client/__init__.py b/synapse/rest/client/__init__.py index 629e2df74a4f..f9830cc51f84 100644 --- a/synapse/rest/client/__init__.py +++ b/synapse/rest/client/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2014-2016 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
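The long run of renames that follows flattens `synapse.rest.client.v1` and `synapse.rest.client.v2_alpha` into `synapse.rest.client`, so downstream imports change accordingly. A sketch based on the rewritten import lines in these hunks:

```python
# Before the rename:
#   from synapse.rest.client.v2_alpha._base import client_patterns
#   from synapse.rest.client.v1 import login

# After the rename:
from synapse.rest.client._base import client_patterns
from synapse.rest.client import login
```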
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/_base.py similarity index 100% rename from synapse/rest/client/v2_alpha/_base.py rename to synapse/rest/client/_base.py diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/account.py similarity index 100% rename from synapse/rest/client/v2_alpha/account.py rename to synapse/rest/client/account.py diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/account_data.py similarity index 100% rename from synapse/rest/client/v2_alpha/account_data.py rename to synapse/rest/client/account_data.py diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/account_validity.py similarity index 100% rename from synapse/rest/client/v2_alpha/account_validity.py rename to synapse/rest/client/account_validity.py diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/auth.py similarity index 100% rename from synapse/rest/client/v2_alpha/auth.py rename to synapse/rest/client/auth.py diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/capabilities.py similarity index 100% rename from synapse/rest/client/v2_alpha/capabilities.py rename to synapse/rest/client/capabilities.py diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/devices.py similarity index 100% rename from synapse/rest/client/v2_alpha/devices.py rename to synapse/rest/client/devices.py diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/directory.py similarity index 98% rename from synapse/rest/client/v1/directory.py rename to synapse/rest/client/directory.py index ae92a3df8e35..ffa075c8e5f6 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/directory.py @@ -23,7 +23,7 @@ SynapseError, ) from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client._base import client_patterns from synapse.types import RoomAlias logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/events.py similarity index 98% rename from synapse/rest/client/v1/events.py rename to synapse/rest/client/events.py index ee7454996e5a..52bb579cfd40 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/events.py @@ -17,7 +17,7 @@ from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet -from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client._base import client_patterns from synapse.streams.config import PaginationConfig logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/filter.py similarity index 100% rename from synapse/rest/client/v2_alpha/filter.py rename to synapse/rest/client/filter.py diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/groups.py similarity index 100% rename from synapse/rest/client/v2_alpha/groups.py rename to synapse/rest/client/groups.py diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/initial_sync.py similarity index 96% rename from synapse/rest/client/v1/initial_sync.py rename to synapse/rest/client/initial_sync.py index bef1edc838ab..12ba0e91dbd1 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -14,7 +14,7 @@ from synapse.http.servlet import RestServlet, parse_boolean -from synapse.rest.client.v2_alpha._base import 
client_patterns +from synapse.rest.client._base import client_patterns from synapse.streams.config import PaginationConfig diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/keys.py similarity index 100% rename from synapse/rest/client/v2_alpha/keys.py rename to synapse/rest/client/keys.py diff --git a/synapse/rest/client/v2_alpha/knock.py b/synapse/rest/client/knock.py similarity index 100% rename from synapse/rest/client/v2_alpha/knock.py rename to synapse/rest/client/knock.py diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/login.py similarity index 99% rename from synapse/rest/client/v1/login.py rename to synapse/rest/client/login.py index 11567bf32cef..0c8d8967b7ee 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/login.py @@ -34,7 +34,7 @@ parse_string, ) from synapse.http.site import SynapseRequest -from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import JsonDict, UserID diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/logout.py similarity index 97% rename from synapse/rest/client/v1/logout.py rename to synapse/rest/client/logout.py index 5aa7908d73a6..6055cac2bd0a 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/logout.py @@ -15,7 +15,7 @@ import logging from synapse.http.servlet import RestServlet -from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client._base import client_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/notifications.py similarity index 100% rename from synapse/rest/client/v2_alpha/notifications.py rename to synapse/rest/client/notifications.py diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/openid.py similarity index 100% rename from synapse/rest/client/v2_alpha/openid.py rename to synapse/rest/client/openid.py diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/password_policy.py similarity index 100% rename from synapse/rest/client/v2_alpha/password_policy.py rename to synapse/rest/client/password_policy.py diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/presence.py similarity index 98% rename from synapse/rest/client/v1/presence.py rename to synapse/rest/client/presence.py index 2b24fe5aa65f..6c27e5faf986 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/presence.py @@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client._base import client_patterns from synapse.types import UserID logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/profile.py similarity index 98% rename from synapse/rest/client/v1/profile.py rename to synapse/rest/client/profile.py index f42f4b35674f..5463ed2c4f85 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/profile.py @@ -16,7 +16,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client._base import 
client_patterns from synapse.types import UserID diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/push_rule.py similarity index 99% rename from synapse/rest/client/v1/push_rule.py rename to synapse/rest/client/push_rule.py index be29a0b39ec6..702b351d183c 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -26,7 +26,7 @@ from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP -from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client._base import client_patterns from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/pusher.py similarity index 98% rename from synapse/rest/client/v1/pusher.py rename to synapse/rest/client/pusher.py index 18102eca6c1b..84619c5e4184 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/pusher.py @@ -23,7 +23,7 @@ parse_string, ) from synapse.push import PusherConfigException -from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client._base import client_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/read_marker.py similarity index 100% rename from synapse/rest/client/v2_alpha/read_marker.py rename to synapse/rest/client/read_marker.py diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/receipts.py similarity index 100% rename from synapse/rest/client/v2_alpha/receipts.py rename to synapse/rest/client/receipts.py diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/register.py similarity index 99% rename from synapse/rest/client/v2_alpha/register.py rename to synapse/rest/client/register.py index 4d31584acd20..58b8e8f2614f 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/register.py @@ -115,7 +115,7 @@ async def on_POST(self, request): # For emails, canonicalise the address. # We store all email addresses canonicalised in the DB. # (See on_POST in EmailThreepidRequestTokenRestServlet - # in synapse/rest/client/v2_alpha/account.py) + # in synapse/rest/client/account.py) try: email = validate_email(body["email"]) except ValueError as e: @@ -631,7 +631,7 @@ async def on_POST(self, request): # For emails, canonicalise the address. # We store all email addresses canonicalised in the DB. 
# (See on_POST in EmailThreepidRequestTokenRestServlet - # in synapse/rest/client/v2_alpha/account.py) + # in synapse/rest/client/account.py) if medium == "email": try: address = canonicalise_email(address) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/relations.py similarity index 100% rename from synapse/rest/client/v2_alpha/relations.py rename to synapse/rest/client/relations.py diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/report_event.py similarity index 100% rename from synapse/rest/client/v2_alpha/report_event.py rename to synapse/rest/client/report_event.py diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/room.py similarity index 65% rename from synapse/rest/client/v1/room.py rename to synapse/rest/client/room.py index 502a91758813..c5c54564bed3 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/room.py @@ -19,19 +19,19 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from urllib import parse as urlparse -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, - HttpResponseException, InvalidClientCredentialsError, + MissingClientTokenError, ShadowBanError, SynapseError, ) from synapse.api.filtering import Filter -from synapse.appservice import ApplicationService from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( + ResolveRoomIdMixin, RestServlet, assert_params_in_dict, parse_boolean, @@ -42,20 +42,11 @@ ) from synapse.http.site import SynapseRequest from synapse.logging.opentracing import set_tag +from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache -from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import ( - JsonDict, - Requester, - RoomAlias, - RoomID, - StreamToken, - ThirdPartyInstanceID, - UserID, - create_requester, -) +from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string @@ -269,405 +260,11 @@ def on_PUT(self, request, room_id, event_type, txn_id): ) -class RoomBatchSendEventRestServlet(TransactionRestServlet): - """ - API endpoint which can insert a chunk of events historically back in time - next to the given `prev_event`. - - `chunk_id` comes from `next_chunk_id `in the response of the batch send - endpoint and is derived from the "insertion" events added to each chunk. - It's not required for the first batch send. - - `state_events_at_start` is used to define the historical state events - needed to auth the events like join events. These events will float - outside of the normal DAG as outlier's and won't be visible in the chat - history which also allows us to insert multiple chunks without having a bunch - of `@mxid joined the room` noise between each chunk. - - `events` is chronological chunk/list of events you want to insert. - There is a reverse-chronological constraint on chunks so once you insert - some messages, you can only insert older ones after that. - tldr; Insert chunks from your most recent history -> oldest history. - - POST /_matrix/client/unstable/org.matrix.msc2716/rooms//batch_send?prev_event=&chunk_id= - { - "events": [ ... 
], - "state_events_at_start": [ ... ] - } - """ - - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2716" - "/rooms/(?P<room_id>[^/]*)/batch_send$" - ), - ) - - def __init__(self, hs): - super().__init__(hs) - self.hs = hs - self.store = hs.get_datastore() - self.state_store = hs.get_storage().state - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() - self.auth = hs.get_auth() - - async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: - ( - most_recent_prev_event_id, - most_recent_prev_event_depth, - ) = await self.store.get_max_depth_of(prev_event_ids) - - # We want to insert the historical event after the `prev_event` but before the successor event - # - # We inherit depth from the successor event instead of the `prev_event` - # because events returned from `/messages` are first sorted by `topological_ordering` - # which is just the `depth` and then tie-break with `stream_ordering`. - # - # We mark these inserted historical events as "backfilled" which gives them a - # negative `stream_ordering`. If we use the same depth as the `prev_event`, - # then our historical event will tie-break and be sorted before the `prev_event` - # when it should come after. - # - # We want to use the successor event depth so they appear after `prev_event` because - # it has a larger `depth` but before the successor event because the `stream_ordering` - # is negative before the successor event. - successor_event_ids = await self.store.get_successor_events( - [most_recent_prev_event_id] - ) - - # If we can't find any successor events, then it's a forward extremity of - # historical messages and we can just inherit from the previous historical - # event which we can already assume has the correct depth where we want - # to insert into. - if not successor_event_ids: - depth = most_recent_prev_event_depth - else: - ( - _, - oldest_successor_depth, - ) = await self.store.get_min_depth_of(successor_event_ids) - - depth = oldest_successor_depth - - return depth - - def _create_insertion_event_dict( - self, sender: str, room_id: str, origin_server_ts: int - ): - """Creates an event dict for an "insertion" event with the proper fields - and a random chunk ID. - - Args: - sender: The event author MXID - room_id: The room ID that the event belongs to - origin_server_ts: Timestamp when the event was sent - - Returns: - Tuple of event ID and stream ordering position - """ - - next_chunk_id = random_string(8) - insertion_event = { - "type": EventTypes.MSC2716_INSERTION, - "sender": sender, - "room_id": room_id, - "content": { - EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, - EventContentFields.MSC2716_HISTORICAL: True, - }, - "origin_server_ts": origin_server_ts, - } - - return insertion_event - - async def _create_requester_for_user_id_from_app_service( - self, user_id: str, app_service: ApplicationService - ) -> Requester: - """Creates a new requester for the given user_id - and validates that the app service is allowed to control - the given user.
- - Args: - user_id: The author MXID that the app service is controlling - app_service: The app service that controls the user - - Returns: - Requester object - """ - - await self.auth.validate_appservice_can_control_user_id(app_service, user_id) - - return create_requester(user_id, app_service=app_service) - - async def on_POST(self, request, room_id): - requester = await self.auth.get_user_by_req(request, allow_guest=False) - - if not requester.app_service: - raise AuthError( - 403, - "Only application services can use the /batchsend endpoint", - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["state_events_at_start", "events"]) - - prev_events_from_query = parse_strings_from_args(request.args, "prev_event") - chunk_id_from_query = parse_string(request, "chunk_id") - - if prev_events_from_query is None: - raise SynapseError( - 400, - "prev_event query parameter is required when inserting historical messages back in time", - errcode=Codes.MISSING_PARAM, - ) - - # For the event we are inserting next to (`prev_events_from_query`), - # find the most recent auth events (derived from state events) that - # allowed that message to be sent. We will use that as a base - # to auth our historical messages against. - ( - most_recent_prev_event_id, - _, - ) = await self.store.get_max_depth_of(prev_events_from_query) - # mapping from (type, state_key) -> state_event_id - prev_state_map = await self.state_store.get_state_ids_for_event( - most_recent_prev_event_id - ) - # List of state event ID's - prev_state_ids = list(prev_state_map.values()) - auth_event_ids = prev_state_ids - - for state_event in body["state_events_at_start"]: - assert_params_in_dict( - state_event, ["type", "origin_server_ts", "content", "sender"] - ) - - logger.debug( - "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", - state_event, - auth_event_ids, - ) - - event_dict = { - "type": state_event["type"], - "origin_server_ts": state_event["origin_server_ts"], - "content": state_event["content"], - "room_id": room_id, - "sender": state_event["sender"], - "state_key": state_event["state_key"], - } - - # Make the state events float off on their own - fake_prev_event_id = "$" + random_string(43) - - # TODO: This is pretty much the same as some other code to handle inserting state in this file - if event_dict["type"] == EventTypes.Member: - membership = event_dict["content"].get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - target=UserID.from_string(event_dict["state_key"]), - room_id=room_id, - action=membership, - content=event_dict["content"], - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - else: - # TODO: Add some complement tests that adds state that is not member joins - # and will use this code path. Maybe we only want to support join state events - # and can get rid of this `else`? 
- ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - event_dict, - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - event_id = event.event_id - - auth_event_ids.append(event_id) - - events_to_create = body["events"] - - inherited_depth = await self._inherit_depth_from_prev_ids( - prev_events_from_query - ) - - # Figure out which chunk to connect to. If they passed in - # chunk_id_from_query let's use it. The chunk ID passed in comes - # from the chunk_id in the "insertion" event from the previous chunk. - last_event_in_chunk = events_to_create[-1] - chunk_id_to_connect_to = chunk_id_from_query - base_insertion_event = None - if chunk_id_from_query: - # All but the first base insertion event should point at a fake - # event, which causes the HS to ask for the state at the start of - # the chunk later. - prev_event_ids = [fake_prev_event_id] - # TODO: Verify the chunk_id_from_query corresponds to an insertion event - pass - # Otherwise, create an insertion event to act as a starting point. - # - # We don't always have an insertion event to start hanging more history - # off of (ideally there would be one in the main DAG, but that's not the - # case if we're wanting to add history to e.g. existing rooms without - # an insertion event), in which case we just create a new insertion event - # that can then get pointed to by a "marker" event later. - else: - prev_event_ids = prev_events_from_query - - base_insertion_event_dict = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - origin_server_ts=last_event_in_chunk["origin_server_ts"], - ) - base_insertion_event_dict["prev_events"] = prev_event_ids.copy() - - ( - base_insertion_event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - base_insertion_event_dict["sender"], - requester.app_service, - ), - base_insertion_event_dict, - prev_event_ids=base_insertion_event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - - chunk_id_to_connect_to = base_insertion_event["content"][ - EventContentFields.MSC2716_NEXT_CHUNK_ID - ] - - # Connect this current chunk to the insertion event from the previous chunk - chunk_event = { - "type": EventTypes.MSC2716_CHUNK, - "sender": requester.user.to_string(), - "room_id": room_id, - "content": {EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to}, - # Since the chunk event is put at the end of the chunk, - # where the newest-in-time event is, copy the origin_server_ts from - # the last event we're inserting - "origin_server_ts": last_event_in_chunk["origin_server_ts"], - } - # Add the chunk event to the end of the chunk (newest-in-time) - events_to_create.append(chunk_event) - - # Add an "insertion" event to the start of each chunk (next to the oldest-in-time - # event in the chunk) so the next chunk can be connected to this one. 
- insertion_event = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - # Since the insertion event is put at the start of the chunk, - # where the oldest-in-time event is, copy the origin_server_ts from - # the first event we're inserting - origin_server_ts=events_to_create[0]["origin_server_ts"], - ) - # Prepend the insertion event to the start of the chunk (oldest-in-time) - events_to_create = [insertion_event] + events_to_create - - event_ids = [] - events_to_persist = [] - for ev in events_to_create: - assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) - - # Mark all events as historical - # This has important semantics within the Synapse internals to backfill properly - ev["content"][EventContentFields.MSC2716_HISTORICAL] = True - - event_dict = { - "type": ev["type"], - "origin_server_ts": ev["origin_server_ts"], - "content": ev["content"], - "room_id": room_id, - "sender": ev["sender"], # requester.user.to_string(), - "prev_events": prev_event_ids.copy(), - } - - event, context = await self.event_creation_handler.create_event( - await self._create_requester_for_user_id_from_app_service( - ev["sender"], requester.app_service - ), - event_dict, - prev_event_ids=event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - logger.debug( - "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", - event, - prev_event_ids, - auth_event_ids, - ) - - assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( - event.sender, - ) - - events_to_persist.append((event, context)) - event_id = event.event_id - - event_ids.append(event_id) - prev_event_ids = [event_id] - - # Persist events in reverse-chronological order so they have the - # correct stream_ordering as they are backfilled (which decrements). - # Events are sorted by (topological_ordering, stream_ordering) - # where topological_ordering is just depth. - for (event, context) in reversed(events_to_persist): - ev = await self.event_creation_handler.handle_new_client_event( - await self._create_requester_for_user_id_from_app_service( - event["sender"], requester.app_service - ), - event=event, - context=context, - ) - - # Add the base_insertion_event to the bottom of the list we return - if base_insertion_event is not None: - event_ids.append(base_insertion_event.event_id) - - return 200, { - "state_events": auth_event_ids, - "events": event_ids, - "next_chunk_id": insertion_event["content"][ - EventContentFields.MSC2716_NEXT_CHUNK_ID - ], - } - - def on_GET(self, request, room_id): - return 501, "Not implemented" - - def on_PUT(self, request, room_id): - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id - ) - - # TODO: Needs unit testing for room ID + alias joins -class JoinRoomAliasServlet(TransactionRestServlet): +class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): def __init__(self, hs): super().__init__(hs) - self.room_member_handler = hs.get_room_member_handler() + super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up self.auth = hs.get_auth() def register(self, http_server): @@ -690,24 +287,13 @@ async def on_POST( # cheekily send invalid bodies. 
content = {} - if RoomID.is_valid(room_identifier): - room_id = room_identifier - - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] - args: Dict[bytes, List[bytes]] = request.args # type: ignore - - remote_room_hosts = parse_strings_from_args( - args, "server_name", required=False - ) - elif RoomAlias.is_valid(room_identifier): - handler = self.room_member_handler - room_alias = RoomAlias.from_string(room_identifier) - room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias) - room_id = room_id_obj.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + remote_room_hosts = parse_strings_from_args(args, "server_name", required=False) + room_id, remote_room_hosts = await self.resolve_room_id( + room_identifier, + remote_room_hosts, + ) await self.room_member_handler.update_membership( requester=requester, @@ -778,12 +364,9 @@ async def on_GET(self, request): Codes.INVALID_PARAM, ) - try: - data = await handler.get_remote_public_room_list( - server, limit=limit, since_token=since_token - ) - except HttpResponseException as e: - raise e.to_synapse_error() + data = await handler.get_remote_public_room_list( + server, limit=limit, since_token=since_token + ) else: data = await handler.get_local_public_room_list( limit=limit, since_token=since_token @@ -831,17 +414,15 @@ async def on_POST(self, request): Codes.INVALID_PARAM, ) - try: - data = await handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - except HttpResponseException as e: - raise e.to_synapse_error() + data = await handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + else: data = await handler.get_local_public_room_list( limit=limit, @@ -1405,18 +986,26 @@ class RoomSpaceSummaryRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._auth = hs.get_auth() - self._space_summary_handler = hs.get_space_summary_handler() + self._room_summary_handler = hs.get_room_summary_handler() async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request, allow_guest=True) - return 200, await self._space_summary_handler.get_space_summary( + max_rooms_per_space = parse_integer(request, "max_rooms_per_space") + if max_rooms_per_space is not None and max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) + + return 200, await self._room_summary_handler.get_space_summary( requester.user.to_string(), room_id, suggested_only=parse_boolean(request, "suggested_only", default=False), - max_rooms_per_space=parse_integer(request, "max_rooms_per_space"), + max_rooms_per_space=max_rooms_per_space, ) # TODO When switching to the stable endpoint, remove the POST handler. 
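The `JoinRoomAliasServlet` hunk above swaps an inline ID-or-alias branch for the shared `ResolveRoomIdMixin.resolve_room_id`. For reference, here is a standalone restatement of the logic being centralised, reconstructed from the deleted lines; the mixin's actual method signature may differ.

```python
from typing import List, Optional, Tuple

from synapse.api.errors import SynapseError
from synapse.types import RoomAlias, RoomID


async def resolve_room_id(
    room_member_handler,
    room_identifier: str,
    remote_room_hosts: Optional[List[str]],
) -> Tuple[str, Optional[List[str]]]:
    # A room ID ("!abc:server") passes through unchanged, keeping the
    # caller-supplied candidate servers.
    if RoomID.is_valid(room_identifier):
        return room_identifier, remote_room_hosts
    # An alias ("#room:server") is resolved via the directory, which also
    # yields servers likely to be participating in the room.
    if RoomAlias.is_valid(room_identifier):
        room_alias = RoomAlias.from_string(room_identifier)
        room_id, hosts = await room_member_handler.lookup_room_alias(room_alias)
        return room_id.to_string(), hosts
    raise SynapseError(
        400, "%s was not legal room ID or room alias" % (room_identifier,)
    )
```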
@@ -1433,12 +1022,19 @@ async def on_POST( ) max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON - ) + if max_rooms_per_space is not None: + if not isinstance(max_rooms_per_space, int): + raise SynapseError( + 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON + ) + if max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) - return 200, await self._space_summary_handler.get_space_summary( + return 200, await self._room_summary_handler.get_space_summary( requester.user.to_string(), room_id, suggested_only=suggested_only, @@ -1446,9 +1042,85 @@ ) -def register_servlets(hs: "HomeServer", http_server, is_worker=False): - msc2716_enabled = hs.config.experimental.msc2716_enabled +class RoomHierarchyRestServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc2946" + "/rooms/(?P<room_id>[^/]*)/hierarchy$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._room_summary_handler = hs.get_room_summary_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + + max_depth = parse_integer(request, "max_depth") + if max_depth is not None and max_depth < 0: + raise SynapseError( + 400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON + ) + + limit = parse_integer(request, "limit") + if limit is not None and limit <= 0: + raise SynapseError( + 400, "'limit' must be a positive integer", Codes.BAD_JSON + ) + + return 200, await self._room_summary_handler.get_room_hierarchy( + requester.user.to_string(), + room_id, + suggested_only=parse_boolean(request, "suggested_only", default=False), + max_depth=max_depth, + limit=limit, + from_token=parse_string(request, "from"), + ) + + +class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/im.nheko.summary" + "/rooms/(?P<room_identifier>[^/]*)/summary$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self._auth = hs.get_auth() + self._room_summary_handler = hs.get_room_summary_handler() + + async def on_GET( + self, request: SynapseRequest, room_identifier: str + ) -> Tuple[int, JsonDict]: + try: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + requester_user_id: Optional[str] = requester.user.to_string() + except MissingClientTokenError: + # auth is optional + requester_user_id = None + + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + remote_room_hosts = parse_strings_from_args(args, "via", required=False) + room_id, remote_room_hosts = await self.resolve_room_id( + room_identifier, + remote_room_hosts, + ) + return 200, await self._room_summary_handler.get_room_summary( + requester_user_id, + room_id, + remote_room_hosts, + ) + + +def register_servlets(hs: "HomeServer", http_server, is_worker=False): RoomStateEventRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) @@ -1456,22 +1128,23 @@
JoinRoomAliasServlet(hs).register(http_server) RoomMembershipRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server) - if msc2716_enabled: - RoomBatchSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) RoomSpaceSummaryRestServlet(hs).register(http_server) + RoomHierarchyRestServlet(hs).register(http_server) + if hs.config.experimental.msc3266_enabled: + RoomSummaryRestServlet(hs).register(http_server) RoomEventServlet(hs).register(http_server) JoinedRoomsRestServlet(hs).register(http_server) RoomAliasListServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) + RoomCreateRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. if not is_worker: - RoomCreateRestServlet(hs).register(http_server) RoomForgetRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py new file mode 100644 index 000000000000..3172aba60563 --- /dev/null +++ b/synapse/rest/client/room_batch.py @@ -0,0 +1,441 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.appservice import ApplicationService +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, + parse_string, + parse_strings_from_args, +) +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import Requester, UserID, create_requester +from synapse.util.stringutils import random_string + +logger = logging.getLogger(__name__) + + +class RoomBatchSendEventRestServlet(RestServlet): + """ + API endpoint which can insert a chunk of events historically back in time + next to the given `prev_event`. + + `chunk_id` comes from `next_chunk_id` in the response of the batch send + endpoint and is derived from the "insertion" events added to each chunk. + It's not required for the first batch send. + + `state_events_at_start` is used to define the historical state events + needed to auth the events like join events. These events will float + outside of the normal DAG as outliers and won't be visible in the chat + history which also allows us to insert multiple chunks without having a bunch + of `@mxid joined the room` noise between each chunk. + + `events` is a chronological chunk/list of events you want to insert. + There is a reverse-chronological constraint on chunks so once you insert + some messages, you can only insert older ones after that. + tldr; Insert chunks from your most recent history -> oldest history.
+ + POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event=<eventID>&chunk_id=<chunkID> + { + "events": [ ... ], + "state_events_at_start": [ ... ] + } + """ + + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc2716" + "/rooms/(?P<room_id>[^/]*)/batch_send$" + ), + ) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.store = hs.get_datastore() + self.state_store = hs.get_storage().state + self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + self.txns = HttpTransactionCache(hs) + + async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: + ( + most_recent_prev_event_id, + most_recent_prev_event_depth, + ) = await self.store.get_max_depth_of(prev_event_ids) + + # We want to insert the historical event after the `prev_event` but before the successor event + # + # We inherit depth from the successor event instead of the `prev_event` + # because events returned from `/messages` are first sorted by `topological_ordering` + # which is just the `depth` and then tie-break with `stream_ordering`. + # + # We mark these inserted historical events as "backfilled" which gives them a + # negative `stream_ordering`. If we use the same depth as the `prev_event`, + # then our historical event will tie-break and be sorted before the `prev_event` + # when it should come after. + # + # We want to use the successor event depth so they appear after `prev_event` because + # it has a larger `depth` but before the successor event because the `stream_ordering` + # is negative before the successor event. + successor_event_ids = await self.store.get_successor_events( + [most_recent_prev_event_id] + ) + + # If we can't find any successor events, then it's a forward extremity of + # historical messages and we can just inherit from the previous historical + # event which we can already assume has the correct depth where we want + # to insert into. + if not successor_event_ids: + depth = most_recent_prev_event_depth + else: + ( + _, + oldest_successor_depth, + ) = await self.store.get_min_depth_of(successor_event_ids) + + depth = oldest_successor_depth + + return depth + + def _create_insertion_event_dict( + self, sender: str, room_id: str, origin_server_ts: int + ): + """Creates an event dict for an "insertion" event with the proper fields + and a random chunk ID. + + Args: + sender: The event author MXID + room_id: The room ID that the event belongs to + origin_server_ts: Timestamp when the event was sent + + Returns: + The event dict for the "insertion" event + """ + + next_chunk_id = random_string(8) + insertion_event = { + "type": EventTypes.MSC2716_INSERTION, + "sender": sender, + "room_id": room_id, + "content": { + EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + "origin_server_ts": origin_server_ts, + } + + return insertion_event + + async def _create_requester_for_user_id_from_app_service( + self, user_id: str, app_service: ApplicationService + ) -> Requester: + """Creates a new requester for the given user_id + and validates that the app service is allowed to control + the given user.
+ + Args: + user_id: The author MXID that the app service is controlling + app_service: The app service that controls the user + + Returns: + Requester object + """ + + await self.auth.validate_appservice_can_control_user_id(app_service, user_id) + + return create_requester(user_id, app_service=app_service) + + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=False) + + if not requester.app_service: + raise AuthError( + 403, + "Only application services can use the /batchsend endpoint", + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["state_events_at_start", "events"]) + + prev_events_from_query = parse_strings_from_args(request.args, "prev_event") + chunk_id_from_query = parse_string(request, "chunk_id") + + if prev_events_from_query is None: + raise SynapseError( + 400, + "prev_event query parameter is required when inserting historical messages back in time", + errcode=Codes.MISSING_PARAM, + ) + + # For the event we are inserting next to (`prev_events_from_query`), + # find the most recent auth events (derived from state events) that + # allowed that message to be sent. We will use that as a base + # to auth our historical messages against. + ( + most_recent_prev_event_id, + _, + ) = await self.store.get_max_depth_of(prev_events_from_query) + # mapping from (type, state_key) -> state_event_id + prev_state_map = await self.state_store.get_state_ids_for_event( + most_recent_prev_event_id + ) + # List of state event ID's + prev_state_ids = list(prev_state_map.values()) + auth_event_ids = prev_state_ids + + state_events_at_start = [] + for state_event in body["state_events_at_start"]: + assert_params_in_dict( + state_event, ["type", "origin_server_ts", "content", "sender"] + ) + + logger.debug( + "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", + state_event, + auth_event_ids, + ) + + event_dict = { + "type": state_event["type"], + "origin_server_ts": state_event["origin_server_ts"], + "content": state_event["content"], + "room_id": room_id, + "sender": state_event["sender"], + "state_key": state_event["state_key"], + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + # Make the state events float off on their own + fake_prev_event_id = "$" + random_string(43) + + # TODO: This is pretty much the same as some other code to handle inserting state in this file + if event_dict["type"] == EventTypes.Member: + membership = event_dict["content"].get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), + target=UserID.from_string(event_dict["state_key"]), + room_id=room_id, + action=membership, + content=event_dict["content"], + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + else: + # TODO: Add some complement tests that adds state that is not member joins + # and will use this code path. Maybe we only want to support join state events + # and can get rid of this `else`? 
+ ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), + event_dict, + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + event_id = event.event_id + + state_events_at_start.append(event_id) + auth_event_ids.append(event_id) + + events_to_create = body["events"] + + inherited_depth = await self._inherit_depth_from_prev_ids( + prev_events_from_query + ) + + # Figure out which chunk to connect to. If they passed in + # chunk_id_from_query let's use it. The chunk ID passed in comes + # from the chunk_id in the "insertion" event from the previous chunk. + last_event_in_chunk = events_to_create[-1] + chunk_id_to_connect_to = chunk_id_from_query + base_insertion_event = None + if chunk_id_from_query: + # All but the first base insertion event should point at a fake + # event, which causes the HS to ask for the state at the start of + # the chunk later. + prev_event_ids = [fake_prev_event_id] + # TODO: Verify the chunk_id_from_query corresponds to an insertion event + pass + # Otherwise, create an insertion event to act as a starting point. + # + # We don't always have an insertion event to start hanging more history + # off of (ideally there would be one in the main DAG, but that's not the + # case if we're wanting to add history to e.g. existing rooms without + # an insertion event), in which case we just create a new insertion event + # that can then get pointed to by a "marker" event later. + else: + prev_event_ids = prev_events_from_query + + base_insertion_event_dict = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + origin_server_ts=last_event_in_chunk["origin_server_ts"], + ) + base_insertion_event_dict["prev_events"] = prev_event_ids.copy() + + ( + base_insertion_event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + base_insertion_event_dict["sender"], + requester.app_service, + ), + base_insertion_event_dict, + prev_event_ids=base_insertion_event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + + chunk_id_to_connect_to = base_insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ] + + # Connect this current chunk to the insertion event from the previous chunk + chunk_event = { + "type": EventTypes.MSC2716_CHUNK, + "sender": requester.user.to_string(), + "room_id": room_id, + "content": { + EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, + EventContentFields.MSC2716_HISTORICAL: True, + }, + # Since the chunk event is put at the end of the chunk, + # where the newest-in-time event is, copy the origin_server_ts from + # the last event we're inserting + "origin_server_ts": last_event_in_chunk["origin_server_ts"], + } + # Add the chunk event to the end of the chunk (newest-in-time) + events_to_create.append(chunk_event) + + # Add an "insertion" event to the start of each chunk (next to the oldest-in-time + # event in the chunk) so the next chunk can be connected to this one. 
+ insertion_event = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + # Since the insertion event is put at the start of the chunk, + # where the oldest-in-time event is, copy the origin_server_ts from + # the first event we're inserting + origin_server_ts=events_to_create[0]["origin_server_ts"], + ) + # Prepend the insertion event to the start of the chunk (oldest-in-time) + events_to_create = [insertion_event] + events_to_create + + event_ids = [] + events_to_persist = [] + for ev in events_to_create: + assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) + + event_dict = { + "type": ev["type"], + "origin_server_ts": ev["origin_server_ts"], + "content": ev["content"], + "room_id": room_id, + "sender": ev["sender"], # requester.user.to_string(), + "prev_events": prev_event_ids.copy(), + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + event, context = await self.event_creation_handler.create_event( + await self._create_requester_for_user_id_from_app_service( + ev["sender"], requester.app_service + ), + event_dict, + prev_event_ids=event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + logger.debug( + "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", + event, + prev_event_ids, + auth_event_ids, + ) + + assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( + event.sender, + ) + + events_to_persist.append((event, context)) + event_id = event.event_id + + event_ids.append(event_id) + prev_event_ids = [event_id] + + # Persist events in reverse-chronological order so they have the + # correct stream_ordering as they are backfilled (which decrements). + # Events are sorted by (topological_ordering, stream_ordering) + # where topological_ordering is just depth. 
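The persistence-order comment above is the crux of the depth/stream_ordering interplay, so here is a tiny self-contained illustration (labels and numbers invented) of why backfilled events still sort chronologically under the `(topological_ordering, stream_ordering)` key:

```python
# Each tuple is (label, depth, stream_ordering). Backfilled events receive
# successively more negative stream_ordering, so the event persisted last
# (the oldest one) gets the most negative value and still sorts first under
# the (topological_ordering, stream_ordering) key used by /messages.
events = [
    ("oldest", 10, -3),
    ("middle", 10, -2),
    ("newest", 10, -1),
]
in_order = sorted(events, key=lambda e: (e[1], e[2]))
assert [label for label, _, _ in in_order] == ["oldest", "middle", "newest"]
```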
+ for (event, context) in reversed(events_to_persist): + ev = await self.event_creation_handler.handle_new_client_event( + await self._create_requester_for_user_id_from_app_service( + event["sender"], requester.app_service + ), + event=event, + context=context, + ) + + # Add the base_insertion_event to the bottom of the list we return + if base_insertion_event is not None: + event_ids.append(base_insertion_event.event_id) + + return 200, { + "state_events": state_events_at_start, + "events": event_ids, + "next_chunk_id": insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ], + } + + def on_GET(self, request, room_id): + return 501, "Not implemented" + + def on_PUT(self, request, room_id): + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id + ) + + +def register_servlets(hs, http_server): + msc2716_enabled = hs.config.experimental.msc2716_enabled + + if msc2716_enabled: + RoomBatchSendEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/room_keys.py similarity index 100% rename from synapse/rest/client/v2_alpha/room_keys.py rename to synapse/rest/client/room_keys.py diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py similarity index 100% rename from synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py rename to synapse/rest/client/room_upgrade_rest_servlet.py diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/sendtodevice.py similarity index 100% rename from synapse/rest/client/v2_alpha/sendtodevice.py rename to synapse/rest/client/sendtodevice.py diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/shared_rooms.py similarity index 100% rename from synapse/rest/client/v2_alpha/shared_rooms.py rename to synapse/rest/client/shared_rooms.py diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/sync.py similarity index 98% rename from synapse/rest/client/v2_alpha/sync.py rename to synapse/rest/client/sync.py index e321668698ae..e18f4d01b375 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/sync.py @@ -259,10 +259,11 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter): # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456 response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count - if sync_result.device_unused_fallback_key_types: - response[ - "org.matrix.msc2732.device_unused_fallback_key_types" - ] = sync_result.device_unused_fallback_key_types + # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md + # states that this field should always be included, as long as the server supports the feature. 
+ response[ + "org.matrix.msc2732.device_unused_fallback_key_types" + ] = sync_result.device_unused_fallback_key_types if joined: response["rooms"][Membership.JOIN] = joined diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/tags.py similarity index 100% rename from synapse/rest/client/v2_alpha/tags.py rename to synapse/rest/client/tags.py diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/thirdparty.py similarity index 100% rename from synapse/rest/client/v2_alpha/thirdparty.py rename to synapse/rest/client/thirdparty.py diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/tokenrefresh.py similarity index 100% rename from synapse/rest/client/v2_alpha/tokenrefresh.py rename to synapse/rest/client/tokenrefresh.py diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/user_directory.py similarity index 100% rename from synapse/rest/client/v2_alpha/user_directory.py rename to synapse/rest/client/user_directory.py diff --git a/synapse/rest/client/v1/__init__.py b/synapse/rest/client/v1/__init__.py deleted file mode 100644 index 5e83dba2ed6f..000000000000 --- a/synapse/rest/client/v1/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py deleted file mode 100644 index 5e83dba2ed6f..000000000000 --- a/synapse/rest/client/v2_alpha/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
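Returning to the sync.py hunk above: making `org.matrix.msc2732.device_unused_fallback_key_types` unconditional is what lets clients distinguish "fallback key consumed" from "server didn't say". A hedged sketch of the client-side check that MSC2732 implies; the field values and the upload callback are illustrative, not a Synapse API:

```python
def maybe_replace_fallback_key(sync_body: dict, upload_fallback_key) -> None:
    # With the change above, the field is always present when the server
    # supports the feature, so a missing algorithm reliably means the
    # fallback key was used and should be replaced.
    unused = sync_body.get(
        "org.matrix.msc2732.device_unused_fallback_key_types", []
    )
    if "signed_curve25519" not in unused:
        upload_fallback_key()  # illustrative callback


example_sync = {
    "device_one_time_keys_count": {"signed_curve25519": 42},
    "org.matrix.msc2732.device_unused_fallback_key_types": [],
}
maybe_replace_fallback_key(
    example_sync, lambda: print("uploading new fallback key")
)
```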
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/voip.py similarity index 97% rename from synapse/rest/client/v1/voip.py rename to synapse/rest/client/voip.py index c780ffded5e9..f53020520d37 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/voip.py @@ -17,7 +17,7 @@ import hmac from synapse.http.servlet import RestServlet -from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client._base import client_patterns class VoipRestServlet(RestServlet): diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 4f702f890c1c..0f5ce41ff880 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -836,7 +836,9 @@ async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: return {"deleted": deleted} - async def delete_local_media(self, media_id: str) -> Tuple[List[str], int]: + async def delete_local_media_ids( + self, media_ids: List[str] + ) -> Tuple[List[str], int]: """ Delete the given local or remote media ID from this server @@ -845,7 +847,7 @@ async def delete_local_media(self, media_id: str) -> Tuple[List[str], int]: Returns: A tuple of (list of deleted media IDs, total deleted media IDs). """ - return await self._remove_local_media_from_disk([media_id]) + return await self._remove_local_media_from_disk(media_ids) async def delete_old_local_media( self, diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index 488b97b32e02..fc62a09b7f07 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -46,6 +46,8 @@ def __init__(self, hs: "HomeServer"): self._consent_version = hs.config.consent.user_consent_version def template_search_dirs(): + if hs.config.server.custom_template_directory: + yield hs.config.server.custom_template_directory if hs.config.sso.sso_template_dir: yield hs.config.sso.sso_template_dir yield hs.config.sso.default_template_dir diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index ab24ec0a8e68..c15b83c387c2 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -74,6 +74,8 @@ def __init__(self, hs: "HomeServer"): self._sso_handler = hs.get_sso_handler() def template_search_dirs(): + if hs.config.server.custom_template_directory: + yield hs.config.server.custom_template_directory if hs.config.sso.sso_template_dir: yield hs.config.sso.sso_template_dir yield hs.config.sso.default_template_dir diff --git a/synapse/server.py b/synapse/server.py index 095dba9ad038..de6517663e6b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -34,8 +34,6 @@ ) import twisted.internet.tcp -from twisted.internet import defer -from twisted.mail.smtp import sendmail from twisted.web.iweb import IPolicyForHTTPS from twisted.web.resource import IResource @@ -101,10 +99,10 @@ from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler +from synapse.handlers.room_summary import RoomSummaryHandler from synapse.handlers.search import SearchHandler from synapse.handlers.send_email import SendEmailHandler from synapse.handlers.set_password import SetPasswordHandler -from synapse.handlers.space_summary import SpaceSummaryHandler from synapse.handlers.sso 
import SsoHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler @@ -442,10 +440,6 @@ def get_room_creation_handler(self) -> RoomCreationHandler: def get_room_shutdown_handler(self) -> RoomShutdownHandler: return RoomShutdownHandler(self) - @cache_in_self - def get_sendmail(self) -> Callable[..., defer.Deferred]: - return sendmail - @cache_in_self def get_state_handler(self) -> StateHandler: return StateHandler(self) @@ -778,8 +772,8 @@ def get_account_data_handler(self) -> AccountDataHandler: return AccountDataHandler(self) @cache_in_self - def get_space_summary_handler(self) -> SpaceSummaryHandler: - return SpaceSummaryHandler(self) + def get_room_summary_handler(self) -> RoomSummaryHandler: + return RoomSummaryHandler(self) @cache_in_self def get_event_auth_handler(self) -> EventAuthHandler: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c8015a384857..95d2caff628c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -941,13 +941,13 @@ async def simple_upsert( `lock` should generally be set to True (the default), but can be set to False if either of the following are true: - - * there is a UNIQUE INDEX on the key columns. In this case a conflict - will cause an IntegrityError in which case this function will retry - the update. - - * we somehow know that we are the only thread which will be updating - this table. + 1. there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + 2. we somehow know that we are the only thread which will be updating + this table. + As an additional note, this parameter only matters for old SQLite versions + because we will use native upserts otherwise. Args: table: The table to upsert into diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 8d9f07111db5..01b918e12e10 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -127,9 +127,6 @@ def __init__(self, database: DatabasePool, db_conn, hs): self._clock = hs.get_clock() self.database_engine = database.engine - self._public_room_id_gen = StreamIdGenerator( - db_conn, "public_room_list_stream", "stream_id" - ) self._device_list_id_gen = StreamIdGenerator( db_conn, "device_lists_stream", @@ -170,6 +167,7 @@ def __init__(self, database: DatabasePool, db_conn, hs): sequence_name="cache_invalidation_stream_seq", writers=[], ) + else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1edc96042bbe..1f0a39eac41e 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -755,81 +755,145 @@ async def claim_e2e_one_time_keys( """ @trace - def _claim_e2e_one_time_keys(txn): - sql = ( - "SELECT key_id, key_json FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" + def _claim_e2e_one_time_key_simple( + txn, user_id: str, device_id: str, algorithm: str + ) -> Optional[Tuple[str, str]]: + """Claim OTK for device for DBs that don't support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + sql = """ + SELECT key_id, key_json FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? 
+                LIMIT 1
+            """
+
+            txn.execute(sql, (user_id, device_id, algorithm))
+            otk_row = txn.fetchone()
+            if otk_row is None:
+                return None
+
+            key_id, key_json = otk_row
+
+            self.db_pool.simple_delete_one_txn(
+                txn,
+                table="e2e_one_time_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "algorithm": algorithm,
+                    "key_id": key_id,
+                },
             )
-            fallback_sql = (
-                "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
-                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
-                " LIMIT 1"
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
-            result = {}
-            delete = []
-            used_fallbacks = []
-            for user_id, device_id, algorithm in query_list:
-                user_result = result.setdefault(user_id, {})
-                device_result = user_result.setdefault(device_id, {})
-                txn.execute(sql, (user_id, device_id, algorithm))
-                otk_row = txn.fetchone()
-                if otk_row is not None:
-                    key_id, key_json = otk_row
-                    device_result[algorithm + ":" + key_id] = key_json
-                    delete.append((user_id, device_id, algorithm, key_id))
-                else:
-                    # no one-time key available, so see if there's a fallback
-                    # key
-                    txn.execute(fallback_sql, (user_id, device_id, algorithm))
-                    fallback_row = txn.fetchone()
-                    if fallback_row is not None:
-                        key_id, key_json, used = fallback_row
-                        device_result[algorithm + ":" + key_id] = key_json
-                        if not used:
-                            used_fallbacks.append(
-                                (user_id, device_id, algorithm, key_id)
-                            )
-
-            # drop any one-time keys that were claimed
-            sql = (
-                "DELETE FROM e2e_one_time_keys_json"
-                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
-                " AND key_id = ?"
+
+            return f"{algorithm}:{key_id}", key_json
+
+        @trace
+        def _claim_e2e_one_time_key_returning(
+            txn, user_id: str, device_id: str, algorithm: str
+        ) -> Optional[Tuple[str, str]]:
+            """Claim OTK for device for DBs that support RETURNING.
+
+            Returns:
+                A tuple of key name (algorithm + key ID) and key JSON, if an
+                OTK was found.
+            """
+
+            # We can use RETURNING to do the fetch and DELETE in one step.
+            sql = """
+                DELETE FROM e2e_one_time_keys_json
+                WHERE user_id = ? AND device_id = ? AND algorithm = ?
+                    AND key_id IN (
+                        SELECT key_id FROM e2e_one_time_keys_json
+                        WHERE user_id = ? AND device_id = ? AND algorithm = ?
+                        LIMIT 1
+                    )
+                RETURNING key_id, key_json
+            """
+
+            txn.execute(
+                sql, (user_id, device_id, algorithm, user_id, device_id, algorithm)
             )
-            for user_id, device_id, algorithm, key_id in delete:
-                log_kv(
-                    {
-                        "message": "Executing claim e2e_one_time_keys transaction on database."
-                    }
-                )
-                txn.execute(sql, (user_id, device_id, algorithm, key_id))
-                log_kv({"message": "finished executing and invalidating cache"})
-                self._invalidate_cache_and_stream(
-                    txn, self.count_e2e_one_time_keys, (user_id, device_id)
+            otk_row = txn.fetchone()
+            if otk_row is None:
+                return None
+
+            key_id, key_json = otk_row
+            return f"{algorithm}:{key_id}", key_json
+
+        results = {}
+        for user_id, device_id, algorithm in query_list:
+            if self.database_engine.supports_returning:
+                # If we support the RETURNING clause, we can use a single query
+                # that allows us to use autocommit mode.
+                _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
+                db_autocommit = True
+            else:
+                _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
+                db_autocommit = False
+
+            row = await self.db_pool.runInteraction(
+                "claim_e2e_one_time_keys",
+                _claim_e2e_one_time_key,
+                user_id,
+                device_id,
+                algorithm,
+                db_autocommit=db_autocommit,
+            )
+            if row:
+                device_results = results.setdefault(user_id, {}).setdefault(
+                    device_id, {}
                 )
-            # mark fallback keys as used
-            for user_id, device_id, algorithm, key_id in used_fallbacks:
-                self.db_pool.simple_update_txn(
-                    txn,
-                    "e2e_fallback_keys_json",
-                    {
+                device_results[row[0]] = row[1]
+                continue
+
+            # No one-time key available, so see if there's a fallback
+            # key
+            row = await self.db_pool.simple_select_one(
+                table="e2e_fallback_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "algorithm": algorithm,
+                },
+                retcols=("key_id", "key_json", "used"),
+                desc="_get_fallback_key",
+                allow_none=True,
+            )
+            if row is None:
+                continue
+
+            key_id = row["key_id"]
+            key_json = row["key_json"]
+            used = row["used"]
+
+            # Mark fallback key as used if not already.
+            if not used:
+                await self.db_pool.simple_update_one(
+                    table="e2e_fallback_keys_json",
+                    keyvalues={
                         "user_id": user_id,
                         "device_id": device_id,
                         "algorithm": algorithm,
                         "key_id": key_id,
                     },
-                    {"used": True},
+                    updatevalues={"used": True},
+                    desc="_get_fallback_key_set_used",
                 )
-                self._invalidate_cache_and_stream(
-                    txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+                await self.invalidate_cache_and_stream(
+                    "get_e2e_unused_fallback_key_types", (user_id, device_id)
                 )
-            return result
+            device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
+            device_results[f"{algorithm}:{key_id}"] = key_json

-        return await self.db_pool.runInteraction(
-            "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
-        )
+        return results


 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 44018c1c31ab..bddf5ef19279 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -671,27 +671,97 @@ def _get_auth_chain_difference_txn(
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}

-    async def get_oldest_events_with_depth_in_room(self, room_id):
+    async def get_oldest_event_ids_with_depth_in_room(self, room_id) -> Dict[str, int]:
+        """Gets the oldest events (backwards extremities) in the room along with the
+        approximate depth.
+
+        We use this function so that we can compare and see if someone's current
+        depth at their current scrollback is within pagination range of the
+        event extremities. If the current depth is close to the depth of the given
+        oldest event, we can trigger a backfill.
+
+        Args:
+            room_id: Room where we want to find the oldest events
+
+        Returns:
+            Map from event_id to depth
+        """
+
+        def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
+            # Assemble a dictionary with event_id -> depth for the oldest events
+            # we know of in the room. Backwards extremities are the oldest
+            # events we know of in the room but we only know of them because
+            # some other event referenced them by prev_event and they aren't
+            # persisted in our database yet (meaning we don't know their depth
+            # specifically). So we need to look for the approximate depth from
+            # the events connected to the current backwards extremities.
+            sql = """
+                SELECT b.event_id, MAX(e.depth) FROM events as e
+                /**
+                 * Get the edge connections from the event_edges table
+                 * so we can see whether this event's prev_events point
+                 * to a backward extremity in the next join.
+                 */
+                INNER JOIN event_edges as g
+                ON g.event_id = e.event_id
+                /**
+                 * We find the "oldest" events in the room by looking for
+                 * events connected to backwards extremities (oldest events
+                 * in the room that we know of so far).
+                 */
+                INNER JOIN event_backward_extremities as b
+                ON g.prev_event_id = b.event_id
+                WHERE b.room_id = ? AND g.is_state is ?
+                GROUP BY b.event_id
+            """
+
+            txn.execute(sql, (room_id, False))
+
+            return dict(txn)
+
         return await self.db_pool.runInteraction(
-            "get_oldest_events_with_depth_in_room",
-            self.get_oldest_events_with_depth_in_room_txn,
+            "get_oldest_event_ids_with_depth_in_room",
+            get_oldest_event_ids_with_depth_in_room_txn,
             room_id,
         )

-    def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
-        sql = (
-            "SELECT b.event_id, MAX(e.depth) FROM events as e"
-            " INNER JOIN event_edges as g"
-            " ON g.event_id = e.event_id"
-            " INNER JOIN event_backward_extremities as b"
-            " ON g.prev_event_id = b.event_id"
-            " WHERE b.room_id = ? AND g.is_state is ?"
-            " GROUP BY b.event_id"
-        )
+    async def get_insertion_event_backwards_extremities_in_room(
+        self, room_id
+    ) -> Dict[str, int]:
+        """Get the insertion events we know about that we haven't backfilled yet.

-        txn.execute(sql, (room_id, False))
+        We use this function so that we can compare and see if someone's current
+        depth at their current scrollback is within pagination range of the
+        insertion event. If the current depth is close to the depth of the given
+        insertion event, we can trigger a backfill.

-        return dict(txn)
+        Args:
+            room_id: Room where we want to find the oldest events
+
+        Returns:
+            Map from event_id to depth
+        """
+
+        def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id):
+            sql = """
+                SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
+                /* We only want insertion events that are also marked as backwards extremities */
+                INNER JOIN insertion_event_extremities as b USING (event_id)
+                /* Get the depth of the insertion event from the events table */
+                INNER JOIN events AS e USING (event_id)
+                WHERE b.room_id = ?
+                GROUP BY b.event_id
+            """
+
+            txn.execute(sql, (room_id,))
+
+            return dict(txn)
+
+        return await self.db_pool.runInteraction(
+            "get_insertion_event_backwards_extremities_in_room",
+            get_insertion_event_backwards_extremities_in_room_txn,
+            room_id,
+        )

     async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
         """Returns the event ID and depth for the event that has the max depth from a set of event IDs

@@ -1041,7 +1111,6 @@ def _get_backfill_events(self, txn, room_id, event_list, limit):
                 if row[1] not in event_results:
                     queue.put((-row[0], row[1]))

-            # Navigate up the DAG by prev_event
             txn.execute(query, (event_id, False, limit - len(event_results)))
             prev_event_id_results = txn.fetchall()
             logger.debug(
@@ -1136,6 +1205,19 @@ def _delete_old_forward_extrem_cache_txn(txn):
             _delete_old_forward_extrem_cache_txn,
         )

+    async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
+        await self.db_pool.simple_upsert(
+            table="insertion_event_extremities",
+            keyvalues={"event_id": event_id},
+            values={
+                "event_id": event_id,
+                "room_id": room_id,
+            },
+            insertion_values={},
+            desc="insert_insertion_extremity",
+            lock=False,
+        )
+
     async def insert_received_event_to_staging(
         self, origin: str, event: EventBase
     ) -> None:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 86baf397fbda..40b53274fb3d 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1845,6 +1845,18 @@ def _handle_chunk_event(self, txn: LoggingTransaction, event: EventBase):
             },
         )

+        # When we receive an event with a `chunk_id` referencing the
+        # `next_chunk_id` of the insertion event, we can remove it from the
+        # `insertion_event_extremities` table.
+        sql = """
+            DELETE FROM insertion_event_extremities WHERE event_id IN (
+                SELECT event_id FROM insertion_events
+                WHERE next_chunk_id = ?
+            )
+        """
+
+        txn.execute(sql, (chunk_id,))
+
     def _handle_redaction(self, txn, redacted_event_id):
         """Handles receiving a redaction and checking whether we need to remove
         any redacted relations from the database.
@@ -2101,15 +2113,17 @@ def _update_backward_extremeties(self, txn, events):

         Forward extremities are handled when we first start persisting the events.
         """
+        # From the events passed in, add all of the prev events as backwards extremities.
+        # Ignore any events that are already backwards extremities or outliers.
         query = (
             "INSERT INTO event_backward_extremities (event_id, room_id)"
             " SELECT ?, ? WHERE NOT EXISTS ("
-            " SELECT 1 FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
+            "     SELECT 1 FROM event_backward_extremities"
+            "     WHERE event_id = ? AND room_id = ?"
             " )"
             " AND NOT EXISTS ("
-            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
-            " AND outlier = ?"
+            "     SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+            "     AND outlier = ?"
             " )"
         )

@@ -2123,6 +2137,8 @@ def _update_backward_extremeties(self, txn, events):
             ],
         )

+        # Delete all these events that we've already fetched and now know that their
+        # prev events are the new backwards extremities.
         query = (
             "DELETE FROM event_backward_extremities"
             " WHERE event_id = ? AND room_id = ?"
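The claim_e2e_one_time_keys rewrite above splits the claim into two paths: a SELECT-then-DELETE transaction for engines without RETURNING, and a single DELETE ... RETURNING statement that can run in autocommit mode because the fetch and the delete happen atomically in one query. The following is a minimal standalone sketch of the RETURNING path, not Synapse's real schema: the table is a reduced stand-in, and it assumes a SQLite new enough (3.35+) to support RETURNING.

import sqlite3

# Reduced stand-in for the e2e_one_time_keys_json table; illustrative only.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE e2e_one_time_keys_json "
    "(user_id TEXT, device_id TEXT, algorithm TEXT, key_id TEXT, key_json TEXT)"
)
conn.executemany(
    "INSERT INTO e2e_one_time_keys_json VALUES (?, ?, ?, ?, ?)",
    [
        ("@alice:test", "DEV1", "signed_curve25519", "AAAA", '{"key": "one"}'),
        ("@alice:test", "DEV1", "signed_curve25519", "AAAB", '{"key": "two"}'),
    ],
)


def claim_otk(user_id, device_id, algorithm):
    # The DELETE both picks and removes one matching row in a single
    # statement, so no explicit transaction is needed around it -- which is
    # what lets the real code run this query with autocommit enabled.
    row = conn.execute(
        """
        DELETE FROM e2e_one_time_keys_json
        WHERE user_id = ? AND device_id = ? AND algorithm = ?
            AND key_id IN (
                SELECT key_id FROM e2e_one_time_keys_json
                WHERE user_id = ? AND device_id = ? AND algorithm = ?
                LIMIT 1
            )
        RETURNING key_id, key_json
        """,
        (user_id, device_id, algorithm) * 2,
    ).fetchone()
    return None if row is None else (f"{algorithm}:{row[0]}", row[1])


print(claim_otk("@alice:test", "DEV1", "signed_curve25519"))  # one of the two keys
print(claim_otk("@alice:test", "DEV1", "signed_curve25519"))  # the remaining key
print(claim_otk("@alice:test", "DEV1", "signed_curve25519"))  # None: keys exhausted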
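Likewise, the event_federation.py and events.py hunks above give an insertion event a simple lifecycle: insert_insertion_extremity upserts it into insertion_event_extremities while it still needs backfilling, and _handle_chunk_event removes it once a chunk event arrives whose chunk_id matches the insertion event's next_chunk_id. A rough sketch of that lifecycle, with both tables cut down to the columns these queries touch and INSERT OR IGNORE standing in for simple_upsert:

import sqlite3

# Reduced versions of the tables created by the schema delta above.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE insertion_events (event_id TEXT, next_chunk_id TEXT);
    CREATE TABLE insertion_event_extremities (
        event_id TEXT NOT NULL, room_id TEXT NOT NULL
    );
    CREATE UNIQUE INDEX insertion_event_extremities_event_id
        ON insertion_event_extremities(event_id);
    """
)


def on_insertion_event(event_id, room_id, next_chunk_id):
    # A new insertion event starts out as a backfill point: upsert it into
    # the extremities table (mirrors insert_insertion_extremity above).
    conn.execute(
        "INSERT INTO insertion_events VALUES (?, ?)", (event_id, next_chunk_id)
    )
    conn.execute(
        "INSERT OR IGNORE INTO insertion_event_extremities VALUES (?, ?)",
        (event_id, room_id),
    )


def on_chunk_event(chunk_id):
    # Once a chunk event arrives whose chunk_id matches an insertion event's
    # next_chunk_id, that insertion event has been backfilled and stops being
    # an extremity (mirrors the DELETE in _handle_chunk_event above).
    conn.execute(
        """
        DELETE FROM insertion_event_extremities WHERE event_id IN (
            SELECT event_id FROM insertion_events WHERE next_chunk_id = ?
        )
        """,
        (chunk_id,),
    )


on_insertion_event("$insert1", "!room:test", "chunk-1")
print(conn.execute("SELECT * FROM insertion_event_extremities").fetchall())
# [('$insert1', '!room:test')]: still waiting to be backfilled
on_chunk_event("chunk-1")
print(conn.execute("SELECT * FROM insertion_event_extremities").fetchall())
# []: the insertion event has been backfilled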
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 3c86adab5650..375463e4e979 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -14,7 +14,6 @@ import logging import threading -from collections import namedtuple from typing import ( Collection, Container, @@ -27,6 +26,7 @@ overload, ) +import attr from constantly import NamedConstant, Names from typing_extensions import Literal @@ -42,7 +42,11 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event -from synapse.logging.context import PreserveLoggingContext, current_context +from synapse.logging.context import ( + PreserveLoggingContext, + current_context, + make_deferred_yieldable, +) from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -56,6 +60,8 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -74,7 +80,10 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +@attr.s(slots=True, auto_attribs=True) +class _EventCacheEntry: + event: EventBase + redacted_event: Optional[EventBase] class EventRedactBehaviour(Names): @@ -161,6 +170,13 @@ def __init__(self, database: DatabasePool, db_conn, hs): max_size=hs.config.caches.event_cache_size, ) + # Map from event ID to a deferred that will result in a map from event + # ID to cache entry. Note that the returned dict may not have the + # requested event in it if the event isn't in the DB. + self._current_event_fetches: Dict[ + str, ObservableDeferred[Dict[str, _EventCacheEntry]] + ] = {} + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -476,7 +492,9 @@ async def get_events_as_list( return events - async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db( + self, event_ids: Iterable[str], allow_rejected: bool = False + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -485,53 +503,107 @@ async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): Args: - event_ids (Iterable[str]): The event_ids of the events to fetch + event_ids: The event_ids of the events to fetch - allow_rejected (bool): Whether to include rejected events. If False, + allow_rejected: Whether to include rejected events. If False, rejected events are omitted from the response. 
Returns: - Dict[str, _EventCacheEntry]: - map from event id to result + map from event id to result """ event_entry_map = self._get_events_from_cache( - event_ids, allow_rejected=allow_rejected + event_ids, ) - missing_events_ids = [e for e in event_ids if e not in event_entry_map] + missing_events_ids = {e for e in event_ids if e not in event_entry_map} + + # We now look up if we're already fetching some of the events in the DB, + # if so we wait for those lookups to finish instead of pulling the same + # events out of the DB multiple times. + already_fetching: Dict[str, defer.Deferred] = {} + + for event_id in missing_events_ids: + deferred = self._current_event_fetches.get(event_id) + if deferred is not None: + # We're already pulling the event out of the DB. Add the deferred + # to the collection of deferreds to wait on. + already_fetching[event_id] = deferred.observe() + + missing_events_ids.difference_update(already_fetching) if missing_events_ids: log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) + # Add entries to `self._current_event_fetches` for each event we're + # going to pull from the DB. We use a single deferred that resolves + # to all the events we pulled from the DB (this will result in this + # function returning more events than requested, but that can happen + # already due to `_get_events_from_db`). + fetching_deferred: ObservableDeferred[ + Dict[str, _EventCacheEntry] + ] = ObservableDeferred(defer.Deferred()) + for event_id in missing_events_ids: + self._current_event_fetches[event_id] = fetching_deferred + # Note that _get_events_from_db is also responsible for turning db rows # into FrozenEvents (via _get_event_from_row), which involves seeing if # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = await self._get_events_from_db( - missing_events_ids, allow_rejected=allow_rejected - ) + try: + missing_events = await self._get_events_from_db( + missing_events_ids, + ) - event_entry_map.update(missing_events) + event_entry_map.update(missing_events) + except Exception as e: + with PreserveLoggingContext(): + fetching_deferred.errback(e) + raise e + finally: + # Ensure that we mark these events as no longer being fetched. + for event_id in missing_events_ids: + self._current_event_fetches.pop(event_id, None) + + with PreserveLoggingContext(): + fetching_deferred.callback(missing_events) + + if already_fetching: + # Wait for the other event requests to finish and add their results + # to ours. + results = await make_deferred_yieldable( + defer.gatherResults( + already_fetching.values(), + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) + + for result in results: + event_entry_map.update(result) + + if not allow_rejected: + event_entry_map = { + event_id: entry + for event_id, entry in event_entry_map.items() + if not entry.event.rejected_reason + } return event_entry_map def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches + def _get_events_from_cache( + self, events: Iterable[str], update_metrics: bool = True + ) -> Dict[str, _EventCacheEntry]: + """Fetch events from the caches. 
- Args: - events (Iterable[str]): list of event_ids to fetch - allow_rejected (bool): Whether to return events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics + May return rejected events. - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. If - allow_rejected is `False` then there will still be an entry but it - will be `None` + Args: + events: list of event_ids to fetch + update_metrics: Whether to update the cache hit ratio metrics """ event_map = {} @@ -542,10 +614,7 @@ def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): if not ret: continue - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None + event_map[event_id] = ret return event_map @@ -672,23 +741,23 @@ def fire(evs, exc): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - async def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db( + self, event_ids: Iterable[str] + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the database. + May return rejected events. + Returned events will be added to the cache for future lookups. Unknown events are omitted from the response. Args: - event_ids (Iterable[str]): The event_ids of the events to fetch - - allow_rejected (bool): Whether to include rejected events. If False, - rejected events are omitted from the response. + event_ids: The event_ids of the events to fetch Returns: - Dict[str, _EventCacheEntry]: - map from event id to result. May return extra events which - weren't asked for. + map from event id to result. May return extra events which + weren't asked for. """ fetched_events = {} events_to_fetch = event_ids @@ -717,9 +786,6 @@ async def _get_events_from_db(self, event_ids, allow_rejected=False): rejected_reason = row["rejected_reason"] - if not allow_rejected and rejected_reason: - continue - # If the event or metadata cannot be parsed, log the error and act # as if the event is unknown. try: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 6ad1a0cf7fbb..c67bea81c6b9 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -29,7 +29,7 @@ from synapse.storage.types import Connection, Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import UserID +from synapse.types import UserID, UserInfo from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -146,6 +146,7 @@ def __init__( @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + """Deprecated: use get_userinfo_by_id instead""" return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, @@ -166,6 +167,33 @@ async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: desc="get_user_by_id", ) + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get a UserInfo object for a user by user ID. + + Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, + this method should be cached. + + Args: + user_id: The user to fetch user info for. + Returns: + `UserInfo` object if user found, otherwise `None`. 
+ """ + user_data = await self.get_user_by_id(user_id) + if not user_data: + return None + return UserInfo( + appservice_id=user_data["appservice_id"], + consent_server_notice_sent=user_data["consent_server_notice_sent"], + consent_version=user_data["consent_version"], + creation_ts=user_data["creation_ts"], + is_admin=bool(user_data["admin"]), + is_deactivated=bool(user_data["deactivated"]), + is_guest=bool(user_data["is_guest"]), + is_shadow_banned=bool(user_data["shadow_banned"]), + user_id=UserID.from_string(user_data["name"]), + user_type=user_data["user_type"], + ) + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config @@ -571,6 +599,28 @@ async def record_user_external_id( desc="record_user_external_id", ) + async def remove_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> None: + """Remove a mapping from an external user id to a mxid + + If the mapping is not found, this method does nothing. + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + await self.db_pool.simple_delete( + table="user_external_ids", + keyvalues={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="remove_user_external_id", + ) + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 443e5f331545..f98b89259892 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -73,6 +73,40 @@ def __init__(self, database: DatabasePool, db_conn, hs): self.config = hs.config + async def store_room( + self, + room_id: str, + room_creator_user_id: str, + is_public: bool, + room_version: RoomVersion, + ): + """Stores a room. + + Args: + room_id: The desired room ID, can be None. + room_creator_user_id: The user ID of the room creator. + is_public: True to indicate that this room should appear in + public room lists. + room_version: The version of the room + Raises: + StoreError if the room could not be stored. + """ + try: + await self.db_pool.simple_insert( + "rooms", + { + "room_id": room_id, + "creator": room_creator_user_id, + "is_public": is_public, + "room_version": room_version.identifier, + "has_auth_chain_index": True, + }, + desc="store_room", + ) + except Exception as e: + logger.error("store_room with room_id=%s failed: %s", room_id, e) + raise StoreError(500, "Problem creating room.") + async def get_room(self, room_id: str) -> dict: """Retrieve a room. @@ -890,55 +924,6 @@ def _quarantine_media_txn( return total_media_quarantined - async def get_all_new_public_rooms( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for public rooms replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. 
- - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - if last_id == current_id: - return [], current_id, False - - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db_pool.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) - async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False ) -> Dict[str, dict]: @@ -1391,57 +1376,6 @@ async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion): lock=False, ) - async def store_room( - self, - room_id: str, - room_creator_user_id: str, - is_public: bool, - room_version: RoomVersion, - ): - """Stores a room. - - Args: - room_id: The desired room ID, can be None. - room_creator_user_id: The user ID of the room creator. - is_public: True to indicate that this room should appear in - public room lists. - room_version: The version of the room - Raises: - StoreError if the room could not be stored. - """ - try: - - def store_room_txn(txn, next_id): - self.db_pool.simple_insert_txn( - txn, - "rooms", - { - "room_id": room_id, - "creator": room_creator_user_id, - "is_public": is_public, - "room_version": room_version.identifier, - "has_auth_chain_index": True, - }, - ) - if is_public: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - }, - ) - - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "store_room_txn", store_room_txn, next_id - ) - except Exception as e: - logger.error("store_room with room_id=%s failed: %s", room_id, e) - raise StoreError(500, "Problem creating room.") - async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion ): @@ -1470,49 +1404,14 @@ async def maybe_store_room_on_outlier_membership( lock=False, ) - async def set_room_is_public(self, room_id, is_public): - def set_room_is_public_txn(txn, next_id): - self.db_pool.simple_update_one_txn( - txn, - table="rooms", - keyvalues={"room_id": room_id}, - updatevalues={"is_public": is_public}, - ) - - entries = self.db_pool.simple_select_list_txn( - txn, - table="public_room_list_stream", - keyvalues={ - "room_id": room_id, - "appservice_id": None, - "network_id": None, - }, - retcols=("stream_id", "visibility"), - ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": None, - "network_id": None, - }, - ) + async def 
set_room_is_public(self, room_id: str, is_public: bool) -> None: + await self.db_pool.simple_update_one( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"is_public": is_public}, + desc="set_room_is_public", + ) - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "set_room_is_public", set_room_is_public_txn, next_id - ) self.hs.get_notifier().on_new_replication_data() async def set_room_is_public_appservice( @@ -1533,68 +1432,33 @@ async def set_room_is_public_appservice( list. """ - def set_room_is_public_appservice_txn(txn, next_id): - if is_public: - try: - self.db_pool.simple_insert_txn( - txn, - table="appservice_room_list", - values={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - except self.database_engine.module.IntegrityError: - # We've already inserted, nothing to do. - return - else: - self.db_pool.simple_delete_txn( - txn, - table="appservice_room_list", - keyvalues={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - - entries = self.db_pool.simple_select_list_txn( - txn, - table="public_room_list_stream", + if is_public: + await self.db_pool.simple_upsert( + table="appservice_room_list", keyvalues={ + "appservice_id": appservice_id, + "network_id": network_id, "room_id": room_id, + }, + values={}, + insertion_values={ "appservice_id": appservice_id, "network_id": network_id, + "room_id": room_id, }, - retcols=("stream_id", "visibility"), + desc="set_room_is_public_appservice_true", ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": appservice_id, - "network_id": network_id, - }, - ) - - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "set_room_is_public_appservice", - set_room_is_public_appservice_txn, - next_id, + else: + await self.db_pool.simple_delete( + table="appservice_room_list", + keyvalues={ + "appservice_id": appservice_id, + "network_id": network_id, + "room_id": room_id, + }, + desc="set_room_is_public_appservice_false", ) + self.hs.get_notifier().on_new_replication_data() async def add_event_report( @@ -1787,9 +1651,6 @@ def _get_event_reports_paginate_txn(txn): "get_event_reports_paginate", _get_event_reports_paginate_txn ) - def get_current_public_room_stream_id(self): - return self._public_room_id_gen.get_current_token() - async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. Can be called multiple times. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 68f1b40ea693..e8157ba3d4eb 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -629,14 +629,12 @@ async def _get_joined_users_from_context( # We don't update the event cache hit ratio as it completely throws off # the hit ratio counts. 
After all, we don't populate the cache if we
        # miss it here
-        event_map = self._get_events_from_cache(
-            member_event_ids, allow_rejected=False, update_metrics=False
-        )
+        event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)

         missing_member_event_ids = []
         for event_id in member_event_ids:
             ev_entry = event_map.get(event_id)
-            if ev_entry:
+            if ev_entry and not ev_entry.event.rejected_reason:
                 if ev_entry.event.membership == Membership.JOIN:
                     users_in_room[ev_entry.event.state_key] = ProfileInfo(
                         display_name=ev_entry.event.content.get("displayname", None),
diff --git a/synapse/storage/schema/README.md b/synapse/storage/schema/README.md
index 729f44ea6cf4..4fc2061a3dab 100644
--- a/synapse/storage/schema/README.md
+++ b/synapse/storage/schema/README.md
@@ -1,4 +1,4 @@
 # Synapse Database Schemas

 This directory contains the schema files used to build Synapse databases. For more
-information, see /docs/development/database_schema.md.
+information, see https://matrix-org.github.io/synapse/develop/development/database_schema.html.
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 36340a652aac..a5bc0ee8a560 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,19 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-SCHEMA_VERSION = 61
+SCHEMA_VERSION = 63
 """Represents the expectations made by the codebase about the database schema

 This should be incremented whenever the codebase changes its requirements on the
 shape of the database schema (even if those requirements are backwards-compatible
 with older versions of Synapse).

-See `README.md `_ for more information on how this
-works.
+See https://matrix-org.github.io/synapse/develop/development/database_schema.html
+for more information on how this works.

 Changes in SCHEMA_VERSION = 61:
  - The `user_stats_historical` and `room_stats_historical` tables are not written and
    are not read (previously, they were written but not read).
+
+Changes in SCHEMA_VERSION = 63:
+ - The `public_room_list_stream` table is no longer written to nor read from
+   (previously, it was written to and read from, but not for any significant purpose).
+   https://github.com/matrix-org/synapse/pull/10565
 """
diff --git a/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql b/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql
new file mode 100644
index 000000000000..b731ef284ac1
--- /dev/null
+++ b/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql
@@ -0,0 +1,24 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Add a table that keeps track of which "insertion" events need to be backfilled
+CREATE TABLE IF NOT EXISTS insertion_event_extremities(
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS insertion_event_extremities_event_id ON insertion_event_extremities(event_id);
+CREATE INDEX IF NOT EXISTS insertion_event_extremities_room_id ON insertion_event_extremities(room_id);
diff --git a/synapse/types.py b/synapse/types.py
index 429bb013d2cf..80fa903c4bae 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -751,3 +751,32 @@ def get_verify_key_from_cross_signing_key(key_info):
     # and return that one key
     for key_id, key_data in keys.items():
         return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)))
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class UserInfo:
+    """Holds information about a user. Result of get_userinfo_by_id.
+
+    Attributes:
+        user_id: ID of the user.
+        appservice_id: Application service ID that created this user.
+        consent_server_notice_sent: Version of policy documents the user has been sent.
+        consent_version: Version of policy documents the user has consented to.
+        creation_ts: Creation timestamp of the user.
+        is_admin: True if the user is an admin.
+        is_deactivated: True if the user has been deactivated.
+        is_guest: True if the user is a guest user.
+        is_shadow_banned: True if the user has been shadow-banned.
+        user_type: User type (None for a normal user; 'support' and 'bot' are the
+            other options).
+    """
+
+    user_id: UserID
+    appservice_id: Optional[int]
+    consent_server_notice_sent: Optional[str]
+    consent_version: Optional[str]
+    user_type: Optional[str]
+    creation_ts: int
+    is_admin: bool
+    is_deactivated: bool
+    is_guest: bool
+    is_shadow_banned: bool
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
deleted file mode 100644
index abc12f08374d..000000000000
--- a/synapse/util/jsonobject.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-class JsonEncodedObject:
-    """A common base class for defining protocol units that are represented
-    as JSON.
-
-    Attributes:
-        unrecognized_keys (dict): A dict containing all the key/value pairs we
-            don't recognize.
-    """
-
-    valid_keys = []  # keys we will store
-    """A list of strings that represent keys we know about
-    and can handle. If we have values for these keys they will be
-    included in the `dictionary` instance variable.
-    """
-
-    internal_keys = []  # keys to ignore while building dict
-    """A list of strings that should *not* be encoded into JSON.
-    """
-
-    required_keys = []
-    """A list of strings that we require to exist. If they are not given upon
-    construction it raises an exception.
-    """
-
-    def __init__(self, **kwargs):
-        """Takes the dict of `kwargs` and loads all keys that are *valid*
-        (i.e., are included in the `valid_keys` list) into the dictionary`
-        instance variable.
-
-        Any keys that aren't recognized are added to the `unrecognized_keys`
-        attribute.
- - Args: - **kwargs: Attributes associated with this protocol unit. - """ - for required_key in self.required_keys: - if required_key not in kwargs: - raise RuntimeError("Key %s is required" % required_key) - - self.unrecognized_keys = {} # Keys we were given not listed as valid - for k, v in kwargs.items(): - if k in self.valid_keys or k in self.internal_keys: - self.__dict__[k] = v - else: - self.unrecognized_keys[k] = v - - def get_dict(self): - """Converts this protocol unit into a :py:class:`dict`, ready to be - encoded as JSON. - - The keys it encodes are: `valid_keys` - `internal_keys` - - Returns - dict - """ - d = { - k: _encode(v) - for (k, v) in self.__dict__.items() - if k in self.valid_keys and k not in self.internal_keys - } - d.update(self.unrecognized_keys) - return d - - def get_internal_dict(self): - d = { - k: _encode(v, internal=True) - for (k, v) in self.__dict__.items() - if k in self.valid_keys - } - d.update(self.unrecognized_keys) - return d - - def __str__(self): - return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) - - -def _encode(obj, internal=False): - if type(obj) is list: - return [_encode(o, internal=internal) for o in obj] - - if isinstance(obj, JsonEncodedObject): - if internal: - return obj.get_internal_dict() - else: - return obj.get_dict() - - return obj diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index da24ba0470b6..522daa323d00 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import sys import traceback @@ -20,6 +21,7 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal +from twisted.internet import defer PUBLIC_KEY = ( "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" @@ -141,3 +143,15 @@ def showtraceback(self): self.write("".join(lines)) finally: last_tb = ei = None + + def displayhook(self, obj): + """ + We override the displayhook so that we automatically convert coroutines + into Deferreds. (Our superclass' displayhook will take care of the rest, + by displaying the Deferred if it's ready, or registering a callback + if it's not). 
+ """ + if inspect.iscoroutine(obj): + super().displayhook(defer.ensureDeferred(obj)) + else: + super().displayhook(obj) diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 5527e278dbc3..d66aeb00eb78 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -1,6 +1,6 @@ import synapse from synapse.app.phone_stats_home import start_phone_stats_home -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest from tests.unittest import HomeserverTestCase diff --git a/tests/config/test_base.py b/tests/config/test_base.py index 84ae3b88ae9b..baa5313fb3cc 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py @@ -30,7 +30,7 @@ def test_loading_missing_templates(self): # contain template files with tempfile.TemporaryDirectory() as tmp_dir: # Attempt to load an HTML template from our custom template directory - template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0] + template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0] # If no errors, we should've gotten the default template instead @@ -60,7 +60,7 @@ def test_loading_custom_templates(self): # Attempt to load the template from our custom template directory template = ( - self.hs.config.read_templates([template_filename], tmp_dir) + self.hs.config.read_templates([template_filename], (tmp_dir,)) )[0] # Render the template @@ -74,8 +74,66 @@ def test_loading_custom_templates(self): "Template file did not contain our test string", ) + def test_multiple_custom_template_directories(self): + """Tests that directories are searched in the right order if multiple custom + template directories are provided. + """ + # Create two temporary directories on the filesystem. + tempdirs = [ + tempfile.TemporaryDirectory(), + tempfile.TemporaryDirectory(), + ] + + # Create one template in each directory, whose content is the index of the + # directory in the list. + template_filename = "my_template.html.j2" + for i in range(len(tempdirs)): + tempdir = tempdirs[i] + template_path = os.path.join(tempdir.name, template_filename) + + with open(template_path, "w") as fp: + fp.write(str(i)) + fp.flush() + + # Retrieve the template. + template = ( + self.hs.config.read_templates( + [template_filename], + (td.name for td in tempdirs), + ) + )[0] + + # Test that we got the template we dropped in the first directory in the list. + self.assertEqual(template.render(), "0") + + # Add another template, this one only in the second directory in the list, so we + # can test that the second directory is still searched into when no matching file + # could be found in the first one. + other_template_name = "my_other_template.html.j2" + other_template_path = os.path.join(tempdirs[1].name, other_template_name) + + with open(other_template_path, "w") as fp: + fp.write("hello world") + fp.flush() + + # Retrieve the template. + template = ( + self.hs.config.read_templates( + [other_template_name], + (td.name for td in tempdirs), + ) + )[0] + + # Test that the file has the expected content. + self.assertEqual(template.render(), "hello world") + + # Cleanup the temporary directories manually since we're not using a context + # manager. 
+ for td in tempdirs: + td.cleanup() + def test_loading_template_from_nonexistent_custom_directory(self): with self.assertRaises(ConfigError): self.hs.config.read_templates( - ["some_filename.html"], "a_nonexistent_directory" + ["some_filename.html"], ("a_nonexistent_directory",) ) diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 3f41e9995039..6b87f571b8dc 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -22,7 +22,7 @@ from synapse.handlers.presence import UserPresenceState from synapse.module_api import ModuleApi from synapse.rest import admin -from synapse.rest.client.v1 import login, presence, room +from synapse.rest.client import login, presence, room from synapse.types import JsonDict, StreamToken, create_requester from tests.handlers.test_sync import generate_sync_config diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 48e98aac797d..ca27388ae8a3 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -14,7 +14,7 @@ from synapse.events.snapshot import EventContext from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest from tests.test_utils.event_injection import create_event diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 1a809b2a6ae0..7b486aba4a04 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -16,7 +16,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID from tests import unittest diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 802c5ad299d3..f0aa8ed9db4d 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -6,7 +6,7 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b00dd143d677..65b18fbd7a14 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -21,7 +21,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms from synapse.rest import admin -from synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.types import JsonDict, ReadReceipt from tests.test_utils import make_awaitable diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 173789156459..0b60cc426119 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -19,7 +19,7 @@ from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/federation/transport/test_knocking.py 
b/tests/federation/transport/test_knocking.py index aab44bce4a79..383214ab5046 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -18,7 +18,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import builder from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import RoomAlias diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 18a734daf461..59de1142b157 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -15,12 +15,10 @@ from collections import Counter from unittest.mock import Mock -import synapse.api.errors -import synapse.handlers.admin import synapse.rest.admin import synapse.storage from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 024c5e963cd2..43998020b2eb 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -133,11 +133,131 @@ def test_query_room_alias_exists(self): self.assertEquals(result.room_id, room_id) self.assertEquals(result.servers, servers) - def _mkservice(self, is_interested): + def test_get_3pe_protocols_no_appservices(self): + self.mock_store.get_app_services.return_value = [] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_no_protocols(self): + service = self._mkservice(False, []) + self.mock_store.get_app_services.return_value = [service] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_protocol_no_response(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals(response, {}) + + def test_get_3pe_protocols_select_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, 
"my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_multiple_protocol(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["other-protocol"]) + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called() + self.assertEquals( + response, + { + "my-protocol": {"x-protocol-data": 42, "instances": []}, + "other-protocol": {"x-protocol-data": 42, "instances": []}, + }, + ) + + def test_get_3pe_protocols_multiple_info(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["my-protocol"]) + + async def get_3pe_protocol(service, unusedProtocol): + if service == service_one: + return { + "x-protocol-data": 42, + "instances": [{"desc": "Alice's service"}], + } + if service == service_two: + return { + "x-protocol-data": 36, + "x-not-used": 45, + "instances": [{"desc": "Bob's service"}], + } + raise Exception("Unexpected service") + + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol = get_3pe_protocol + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + # It's expected that the second service's data doesn't appear in the response + self.assertEquals( + response, + { + "my-protocol": { + "x-protocol-data": 42, + "instances": [ + { + "desc": "Alice's service", + }, + {"desc": "Bob's service"}, + ], + }, + }, + ) + + def _mkservice(self, is_interested, protocols=None): service = Mock() service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" + service.protocols = protocols return service def _mkservice_alias(self, is_interested_in_alias): diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 7a8041ab4437..a0a48b564eb1 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -19,7 +19,7 @@ import synapse.api.errors from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig -from synapse.rest.client.v1 import directory, login, room +from synapse.rest.client import directory, login, room from synapse.types import RoomAlias, create_requester from tests import unittest diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 4140fcefc2c4..c72a8972a3f1 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -22,7 +22,7 @@ from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.util.stringutils import random_string from tests import unittest diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index a8a9fc5b628e..8a8d369faca1 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -18,7 +18,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin -from 
synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import create_requester from synapse.util.stringutils import random_string diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 32651db09669..38e6d9f5363a 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,8 +20,7 @@ from twisted.internet import defer import synapse -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import devices +from synapse.rest.client import devices, login from synapse.types import JsonDict from tests import unittest diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 18e92e90d7f4..0a52bc8b721f 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Optional from unittest.mock import Mock, call from signedjson.key import generate_signing_key @@ -33,7 +33,7 @@ handle_update, ) from synapse.rest import admin -from synapse.rest.client.v1 import room +from synapse.rest.client import room from synapse.types import UserID, get_domain_from_id from tests import unittest @@ -339,8 +339,11 @@ def test_persisting_presence_updates(self): class PresenceTimeoutTestCase(unittest.TestCase): + """Tests different timers and that the timer does not change `status_msg` of user.""" + def test_idle_timer(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -348,12 +351,14 @@ def test_idle_timer(self): state=PresenceState.ONLINE, last_active_ts=now - IDLE_TIMER - 1, last_user_sync_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) + self.assertEquals(new_state.status_msg, status_msg) def test_busy_no_idle(self): """ @@ -361,6 +366,7 @@ def test_busy_no_idle(self): presence state into unavailable. """ user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -368,15 +374,18 @@ def test_busy_no_idle(self): state=PresenceState.BUSY, last_active_ts=now - IDLE_TIMER - 1, last_user_sync_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.BUSY) + self.assertEquals(new_state.status_msg, status_msg) def test_sync_timeout(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -384,15 +393,18 @@ def test_sync_timeout(self): state=PresenceState.ONLINE, last_active_ts=0, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.OFFLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_sync_online(self): user_id = "@foo:bar" + status_msg = "I'm here!" 
now = 5000000 state = UserPresenceState.default(user_id) @@ -400,6 +412,7 @@ state=PresenceState.ONLINE, last_active_ts=now - SYNC_ONLINE_TIMEOUT - 1, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout( @@ -408,9 +421,11 @@ self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.ONLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_federation_ping(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -419,12 +434,13 @@ def test_federation_ping(self): last_active_ts=now, last_user_sync_ts=now, last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state, new_state) + self.assertEquals(state, new_state) def test_no_timeout(self): user_id = "@foo:bar" @@ -444,6 +460,7 @@ def test_no_timeout(self): def test_federation_timeout(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -452,6 +469,7 @@ def test_federation_timeout(self): last_active_ts=now, last_user_sync_ts=now, last_federation_update_ts=now - FEDERATION_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout( @@ -460,9 +478,11 @@ def test_federation_timeout(self): self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.OFFLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_last_active(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -471,6 +491,7 @@ def test_last_active(self): last_active_ts=now - LAST_ACTIVE_GRANULARITY - 1, last_user_sync_ts=now, last_federation_update_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) @@ -516,6 +537,144 @@ def test_external_process_timeout(self): ) self.assertEqual(state.state, PresenceState.OFFLINE) + def test_user_goes_offline_by_timeout_status_msg_remain(self): + """Test that if a user doesn't update their records for a while, + their presence goes `OFFLINE` because of the timeout and `status_msg` remains. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Check that if we wait a while without telling the handler the user has + # stopped syncing that their presence state doesn't get timed out. + self.reactor.advance(SYNC_ONLINE_TIMEOUT / 2) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + self.assertEqual(state.status_msg, status_msg) + + # Check that if the timeout fires, then the syncing user gets timed out + self.reactor.advance(SYNC_ONLINE_TIMEOUT) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + # status_msg should remain even after going offline + self.assertEqual(state.state, PresenceState.OFFLINE) + self.assertEqual(state.status_msg, status_msg) + + def test_user_goes_offline_manually_with_no_status_msg(self): + """Test that if a user manually changes their presence to `OFFLINE` + and no status is set, `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" 
+ + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as offline + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), {"presence": PresenceState.OFFLINE} + ) + ) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + self.assertEqual(state.status_msg, None) + + def test_user_goes_offline_manually_with_status_msg(self): + """Test that if a user manually changes their presence to `OFFLINE` + and a status is set, the `status_msg` appears. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as offline + self._set_presencestate_with_status_msg( + user_id, PresenceState.OFFLINE, "And now here." + ) + + def test_user_reset_online_with_no_status(self): + """Test that if a user manually sets their presence again + and no status is set, `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as online again + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), {"presence": PresenceState.ONLINE} + ) + ) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + # status_msg should be reset, since no new one was provided + self.assertEqual(state.state, PresenceState.ONLINE) + self.assertEqual(state.status_msg, None) + + def test_set_presence_with_status_msg_none(self): + """Test that if a user manually sets their presence again + with the status as `None`, the stored `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as online and `status_msg = None` + self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) + + def _set_presencestate_with_status_msg( + self, user_id: str, state: PresenceState, status_msg: Optional[str] + ): + """Set a PresenceState and status_msg and check the result. + + Args: + user_id: User whose status is to be set. + state: The new PresenceState. + status_msg: Status message that is to be set. + """ + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), + {"presence": state, "status_msg": status_msg}, + ) + ) + + new_state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(new_state.state, state) + self.assertEqual(new_state.status_msg, status_msg) + class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index c727ac6bd6ad..732a12c9bd08 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -287,6 +287,11 @@ def test_filters_out_receipt_event_with_only_hidden_receipt_and_ignores_rest(sel ) def test_handles_string_data(self): + """ + Tests that an invalid shape for read-receipts is handled. 
+ Context: https://github.com/matrix-org/synapse/issues/10603 + """ + self._test_filters_hidden( [ { @@ -301,19 +306,7 @@ def test_handles_string_data(self): "type": "m.receipt", }, ], - [ - { - "content": { - "$14356419edgd14394fHBLK:matrix.org": { - "m.read": { - "@rikj:jki.re": "string", - } - }, - }, - "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", - }, - ], + [], ) def _test_filters_hidden( diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py new file mode 100644 index 000000000000..ac800afa7d3e --- /dev/null +++ b/tests/handlers/test_room_summary.py @@ -0,0 +1,959 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Iterable, List, Optional, Tuple +from unittest import mock + +from synapse.api.constants import ( + EventContentFields, + EventTypes, + HistoryVisibility, + JoinRules, + Membership, + RestrictedJoinRuleTypes, + RoomTypes, +) +from synapse.api.errors import AuthError, NotFoundError, SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict +from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID + +from tests import unittest + + +def _create_event(room_id: str, order: Optional[Any] = None): + result = mock.Mock() + result.room_id = room_id + result.content = {} + if order is not None: + result.content["order"] = order + return result + + +def _order(*events): + return sorted(events, key=_child_events_comparison_key) + + +class TestSpaceSummarySort(unittest.TestCase): + def test_no_order_last(self): + """An event with no ordering is placed behind those with an ordering.""" + ev1 = _create_event("!abc:test") + ev2 = _create_event("!xyz:test", "xyz") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def test_order(self): + """The ordering should be used.""" + ev1 = _create_event("!abc:test", "xyz") + ev2 = _create_event("!xyz:test", "abc") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def test_order_room_id(self): + """Room ID is a tie-breaker for ordering.""" + ev1 = _create_event("!abc:test", "abc") + ev2 = _create_event("!xyz:test", "abc") + + self.assertEqual([ev1, ev2], _order(ev1, ev2)) + + def test_invalid_ordering_type(self): + """Invalid orderings are considered the same as missing.""" + ev1 = _create_event("!abc:test", 1) + ev2 = _create_event("!xyz:test", "xyz") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", {}) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", []) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", True) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def test_invalid_ordering_value(self): + """Invalid orderings are considered the same as missing.""" + ev1 = 
_create_event("!abc:test", "foo\n") + ev2 = _create_event("!xyz:test", "xyz") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", "a" * 51) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + +class SpaceSummaryTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): + self.hs = hs + self.handler = self.hs.get_room_summary_handler() + + # Create a user. + self.user = self.register_user("user", "pass") + self.token = self.login("user", "pass") + + # Create a space and a child room. + self.space = self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + self.room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, self.room, self.token) + + def _add_child( + self, space_id: str, room_id: str, token: str, order: Optional[str] = None + ) -> None: + """Add a child room to a space.""" + content: JsonDict = {"via": [self.hs.hostname]} + if order is not None: + content["order"] = order + self.helper.send_state( + space_id, + event_type=EventTypes.SpaceChild, + body=content, + tok=token, + state_key=room_id, + ) + + def _assert_rooms( + self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] + ) -> None: + """ + Assert that the expected room IDs and events are in the response. + + Args: + result: The result from the API call. + rooms_and_children: An iterable of tuples where each tuple is: + The expected room ID. + The expected IDs of any children rooms. + """ + room_ids = [] + children_ids = [] + for room_id, children in rooms_and_children: + room_ids.append(room_id) + if children: + children_ids.extend([(room_id, child_id) for child_id in children]) + self.assertCountEqual( + [room.get("room_id") for room in result["rooms"]], room_ids + ) + self.assertCountEqual( + [ + (event.get("room_id"), event.get("state_key")) + for event in result["events"] + ], + children_ids, + ) + + def _assert_hierarchy( + self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] + ) -> None: + """ + Assert that the expected room IDs are in the response. + + Args: + result: The result from the API call. + rooms_and_children: An iterable of tuples where each tuple is: + The expected room ID. + The expected IDs of any children rooms. + """ + result_room_ids = [] + result_children_ids = [] + for result_room in result["rooms"]: + result_room_ids.append(result_room["room_id"]) + result_children_ids.append( + [ + (cs["room_id"], cs["state_key"]) + for cs in result_room.get("children_state") + ] + ) + + room_ids = [] + children_ids = [] + for room_id, children in rooms_and_children: + room_ids.append(room_id) + children_ids.append([(room_id, child_id) for child_id in children]) + + # Note that order matters. + self.assertEqual(result_room_ids, room_ids) + self.assertEqual(result_children_ids, children_ids) + + def _poke_fed_invite(self, room_id: str, from_user: str) -> None: + """ + Creates an invite (as if received over federation) for the room from the + given remote user. + + Args: + room_id: The room ID to issue an invite for. + from_user: The remote user the invite is from. + """ + # Poke an invite over federation into the database. 
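+ # (The event below is a hand-crafted, minimal m.room.member invite; no real federation join dance is performed.)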
+ fed_handler = self.hs.get_federation_handler() + fed_hostname = UserID.from_string(from_user).domain + event = make_event_from_dict( + { + "room_id": room_id, + "event_id": "!abcd:" + fed_hostname, + "type": EventTypes.Member, + "sender": from_user, + "state_key": self.user, + "content": {"membership": Membership.INVITE}, + "prev_events": [], + "auth_events": [], + "depth": 1, + "origin_server_ts": 1234, + } + ) + self.get_success( + fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) + ) + + def test_simple_space(self): + """Test a simple space with a single room.""" + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + # The result should have the space and the room in it, along with a link + # from space -> room. + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_visibility(self): + """A user not in a space cannot inspect it.""" + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + + # The user can see the space since it is publicly joinable. + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # If the space is made invite-only, it should no longer be viewable. + self.helper.send_state( + self.space, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + + # If the space is made world-readable it should return a result. + self.helper.send_state( + self.space, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # Make it not world-readable again and confirm it results in an error. + self.helper.send_state( + self.space, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.token, + ) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + + # Join the space and results should be returned. + self.helper.invite(self.space, targ=user2, tok=self.token) + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # Attempting to view an unknown room returns the same error. 
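+ # (Presumably so that the existence of a room is not leaked to unauthorised users.)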
+ self.get_failure( + self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + self.get_failure( + self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + + def _create_room_with_join_rule( + self, join_rule: str, room_version: Optional[str] = None, **extra_content + ) -> str: + """Create a room with the given join rule and add it to the space.""" + room_id = self.helper.create_room_as( + self.user, + room_version=room_version, + tok=self.token, + extra_content={ + "initial_state": [ + { + "type": EventTypes.JoinRules, + "state_key": "", + "content": { + "join_rule": join_rule, + **extra_content, + }, + } + ] + }, + ) + self._add_child(self.space, room_id, self.token) + return room_id + + def test_filtering(self): + """ + Rooms should be properly filtered to only include rooms the user has access to. + """ + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + + # Create a few rooms which will have different properties. + public_room = self._create_room_with_join_rule(JoinRules.PUBLIC) + knock_room = self._create_room_with_join_rule( + JoinRules.KNOCK, room_version=RoomVersions.V7.identifier + ) + not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(invited_room, targ=user2, tok=self.token) + restricted_room = self._create_room_with_join_rule( + JoinRules.RESTRICTED, + room_version=RoomVersions.V8.identifier, + allow=[], + ) + restricted_accessible_room = self._create_room_with_join_rule( + JoinRules.RESTRICTED, + room_version=RoomVersions.V8.identifier, + allow=[ + { + "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, + "room_id": self.space, + "via": [self.hs.hostname], + } + ], + ) + world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.send_state( + world_readable_room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + joined_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(joined_room, targ=user2, tok=self.token) + self.helper.join(joined_room, user2, tok=token2) + + # Join the space. + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + expected = [ + ( + self.space, + [ + self.room, + public_room, + knock_room, + not_invited_room, + invited_room, + restricted_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ), + (self.room, ()), + (public_room, ()), + (knock_room, ()), + (invited_room, ()), + (restricted_accessible_room, ()), + (world_readable_room, ()), + (joined_room, ()), + ] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + def test_complex_space(self): + """ + Create a "complex" space to see how it handles things like loops and subspaces. + """ + # Create an inaccessible room. + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + room2 = self.helper.create_room_as(user2, is_public=False, tok=token2) + # This is a bit odd as "user" is adding a room they don't know about, but + # it works for the tests. + self._add_child(self.space, room2, self.token) + + # Create a subspace under the space with an additional room in it. 
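+ # (A space is just a room whose m.room.create content sets the room type to m.space.)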
+ subspace = self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + subroom = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, subspace, token=self.token) + self._add_child(subspace, subroom, token=self.token) + # Also add the two rooms from the space into this subspace (causing loops). + self._add_child(subspace, self.room, token=self.token) + self._add_child(subspace, room2, self.token) + + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + + # The result should include each room a single time and each link. + expected = [ + (self.space, [self.room, room2, subspace]), + (self.room, ()), + (subspace, [subroom, self.room, room2]), + (subroom, ()), + ] + self._assert_rooms(result, expected) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_pagination(self): + """Test simple pagination works.""" + room_ids = [] + for i in range(1, 10): + room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, room, self.token, order=str(i)) + room_ids.append(room) + # The room created initially doesn't have an order, so comes last. + room_ids.append(self.room) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, limit=7) + ) + # The result should have the space and all of the links, plus some of the + # rooms and a pagination token. + expected: List[Tuple[str, Iterable[str]]] = [(self.space, room_ids)] + expected += [(room_id, ()) for room_id in room_ids[:6]] + self._assert_hierarchy(result, expected) + self.assertIn("next_batch", result) + + # Check the next page. + result = self.get_success( + self.handler.get_room_hierarchy( + self.user, self.space, limit=5, from_token=result["next_batch"] + ) + ) + # The result should have the remaining rooms in it, with no further + # pagination token. + expected = [(room_id, ()) for room_id in room_ids[6:]] + self._assert_hierarchy(result, expected) + self.assertNotIn("next_batch", result) + + def test_invalid_pagination_token(self): + """An invalid pagination token, or changing other parameters, should be rejected.""" + room_ids = [] + for i in range(1, 10): + room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, room, self.token, order=str(i)) + room_ids.append(room) + # The room created initially doesn't have an order, so comes last. + room_ids.append(self.room) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, limit=7) + ) + self.assertIn("next_batch", result) + + # Changing the room ID, suggested-only, or max-depth causes an error. + self.get_failure( + self.handler.get_room_hierarchy( + self.user, self.room, from_token=result["next_batch"] + ), + SynapseError, + ) + self.get_failure( + self.handler.get_room_hierarchy( + self.user, + self.space, + suggested_only=True, + from_token=result["next_batch"], + ), + SynapseError, + ) + self.get_failure( + self.handler.get_room_hierarchy( + self.user, self.space, max_depth=0, from_token=result["next_batch"] + ), + SynapseError, + ) + + # An invalid token is rejected. 
+ self.get_failure( + self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), + SynapseError, + ) + + def test_max_depth(self): + """Create a deep tree to test the max depth against.""" + spaces = [self.space] + rooms = [self.room] + for _ in range(5): + spaces.append( + self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": { + EventContentFields.ROOM_TYPE: RoomTypes.SPACE + } + }, + ) + ) + self._add_child(spaces[-2], spaces[-1], self.token) + rooms.append(self.helper.create_room_as(self.user, tok=self.token)) + self._add_child(spaces[-1], rooms[-1], self.token) + + # Test just the space itself. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) + ) + expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])] + self._assert_hierarchy(result, expected) + + # A single additional layer. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) + ) + expected += [ + (rooms[0], ()), + (spaces[1], [rooms[1], spaces[2]]), + ] + self._assert_hierarchy(result, expected) + + # A few layers. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) + ) + expected += [ + (rooms[1], ()), + (spaces[2], [rooms[2], spaces[3]]), + (rooms[2], ()), + (spaces[3], [rooms[3], spaces[4]]), + ] + self._assert_hierarchy(result, expected) + + def test_fed_complex(self): + """ + Return data over federation and ensure that it is handled properly. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + subroom = "#subroom:" + fed_hostname + + # Generate some good data, and some bad data: + # + # * Event *back* to the root room. + # * Unrelated events / rooms + # * Multiple levels of events (in a not-useful order, e.g. grandchild + # events before child events). + + # Note that these entries are brief, but should contain enough info. + requested_room_entry = _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + [ + { + "type": EventTypes.SpaceChild, + "room_id": subspace, + "state_key": subroom, + "content": {"via": [fed_hostname]}, + } + ], + ) + child_room = { + "room_id": subroom, + "world_readable": True, + } + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [ + requested_room_entry, + _RoomEntry( + subroom, + { + "room_id": subroom, + "world_readable": True, + }, + ), + ] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return requested_room_entry, {subroom: child_room}, set() + + # Add a room to the space which is on another server. 
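+ # (The remote summarize methods are patched in below, so no real federation traffic occurs.)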
+ self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + (subspace, [subroom]), + (subroom, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_fed_filtering(self): + """ + Rooms returned over federation should be properly filtered to only include + rooms the user has access to. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + + # Create a few rooms which will have different properties. + public_room = "#public:" + fed_hostname + knock_room = "#knock:" + fed_hostname + not_invited_room = "#not_invited:" + fed_hostname + invited_room = "#invited:" + fed_hostname + restricted_room = "#restricted:" + fed_hostname + restricted_accessible_room = "#restricted_accessible:" + fed_hostname + world_readable_room = "#world_readable:" + fed_hostname + joined_room = self.helper.create_room_as(self.user, tok=self.token) + + # Poke an invite over federation into the database. + self._poke_fed_invite(invited_room, "@remote:" + fed_hostname) + + # Note that these entries are brief, but should contain enough info. + children_rooms = ( + ( + public_room, + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + ), + ( + knock_room, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + ), + ( + not_invited_room, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + invited_room, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + restricted_room, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [], + }, + ), + ( + restricted_accessible_room, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [self.room], + }, + ), + ( + world_readable_room, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + ), + ( + joined_room, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ) + + subspace_room_entry = _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + }, + # Place each room in the sub-space. + [ + { + "type": EventTypes.SpaceChild, + "room_id": subspace, + "state_key": room_id, + "content": {"via": [fed_hostname]}, + } + for room_id, _ in children_rooms + ], + ) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [subspace_room_entry] + [ + # A copy is made of the room data since the allowed_spaces key + # is removed. 
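+ # (dict() takes that shallow copy, leaving children_rooms intact for the hierarchy variant below.)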
+ _RoomEntry(child_room[0], dict(child_room[1])) + for child_room in children_rooms + ] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return subspace_room_entry, dict(children_rooms), set() + + # Add a room to the space which is on another server. + self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + ( + subspace, + [ + public_room, + knock_room, + not_invited_room, + invited_room, + restricted_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ), + (public_room, ()), + (knock_room, ()), + (invited_room, ()), + (restricted_accessible_room, ()), + (world_readable_room, ()), + (joined_room, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_fed_invited(self): + """ + A room which the user was invited to should be included in the response. + + This differs from test_fed_filtering in that the room itself is being + queried over federation, instead of it being included as a sub-room of + a space in the response. + """ + fed_hostname = self.hs.hostname + "2" + fed_room = "#subroom:" + fed_hostname + + # Poke an invite over federation into the database. + self._poke_fed_invite(fed_room, "@remote:" + fed_hostname) + + fed_room_entry = _RoomEntry( + fed_room, + { + "room_id": fed_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [fed_room_entry] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return fed_room_entry, {}, set() + + # Add a room to the space which is on another server. + self._add_child(self.space, fed_room, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, fed_room]), + (self.room, ()), + (fed_room, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + +class RoomSummaryTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): + self.hs = hs + self.handler = self.hs.get_room_summary_handler() + + # Create a user. + self.user = self.register_user("user", "pass") + self.token = self.login("user", "pass") + + # Create a simple room. 
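+ # (It is made invite-only below so that the visibility test starts from a private room.)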
+ self.room = self.helper.create_room_as(self.user, tok=self.token) + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + + def test_own_room(self): + """Test a simple room created by the requester.""" + result = self.get_success(self.handler.get_room_summary(self.user, self.room)) + self.assertEqual(result.get("room_id"), self.room) + + def test_visibility(self): + """A user not in a private room cannot get its summary.""" + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + + # The user cannot see the room. + self.get_failure(self.handler.get_room_summary(user2, self.room), NotFoundError) + + # If the room is made world-readable it should return a result. + self.helper.send_state( + self.room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + result = self.get_success(self.handler.get_room_summary(user2, self.room)) + self.assertEqual(result.get("room_id"), self.room) + + # Make it not world-readable again and confirm it results in an error. + self.helper.send_state( + self.room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.token, + ) + self.get_failure(self.handler.get_room_summary(user2, self.room), NotFoundError) + + # If the room is made public it should return a result. + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.PUBLIC}, + tok=self.token, + ) + result = self.get_success(self.handler.get_room_summary(user2, self.room)) + self.assertEqual(result.get("room_id"), self.room) + + # Join the room, make it invite-only again, and results should be returned. + self.helper.join(self.room, user2, tok=token2) + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + result = self.get_success(self.handler.get_room_summary(user2, self.room)) + self.assertEqual(result.get("room_id"), self.room) diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py deleted file mode 100644 index 01975c13d4fc..000000000000 --- a/tests/handlers/test_space_summary.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any, Iterable, Optional, Tuple -from unittest import mock - -from synapse.api.constants import ( - EventContentFields, - EventTypes, - HistoryVisibility, - JoinRules, - Membership, - RestrictedJoinRuleTypes, - RoomTypes, -) -from synapse.api.errors import AuthError -from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict -from synapse.handlers.space_summary import _child_events_comparison_key -from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.server import HomeServer -from synapse.types import JsonDict - -from tests import unittest - - -def _create_event(room_id: str, order: Optional[Any] = None): - result = mock.Mock() - result.room_id = room_id - result.content = {} - if order is not None: - result.content["order"] = order - return result - - -def _order(*events): - return sorted(events, key=_child_events_comparison_key) - - -class TestSpaceSummarySort(unittest.TestCase): - def test_no_order_last(self): - """An event with no ordering is placed behind those with an ordering.""" - ev1 = _create_event("!abc:test") - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_order(self): - """The ordering should be used.""" - ev1 = _create_event("!abc:test", "xyz") - ev2 = _create_event("!xyz:test", "abc") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_order_room_id(self): - """Room ID is a tie-breaker for ordering.""" - ev1 = _create_event("!abc:test", "abc") - ev2 = _create_event("!xyz:test", "abc") - - self.assertEqual([ev1, ev2], _order(ev1, ev2)) - - def test_invalid_ordering_type(self): - """Invalid orderings are considered the same as missing.""" - ev1 = _create_event("!abc:test", 1) - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", {}) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", []) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", True) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_invalid_ordering_value(self): - """Invalid orderings are considered the same as missing.""" - ev1 = _create_event("!abc:test", "foo\n") - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", "a" * 51) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - -class SpaceSummaryTestCase(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, hs: HomeServer): - self.hs = hs - self.handler = self.hs.get_space_summary_handler() - - # Create a user. - self.user = self.register_user("user", "pass") - self.token = self.login("user", "pass") - - # Create a space and a child room. 
- self.space = self.helper.create_room_as( - self.user, - tok=self.token, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - self.room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, self.room, self.token) - - def _add_child(self, space_id: str, room_id: str, token: str) -> None: - """Add a child room to a space.""" - self.helper.send_state( - space_id, - event_type=EventTypes.SpaceChild, - body={"via": [self.hs.hostname]}, - tok=token, - state_key=room_id, - ) - - def _assert_rooms(self, result: JsonDict, rooms: Iterable[str]) -> None: - """Assert that the expected room IDs are in the response.""" - self.assertCountEqual([room.get("room_id") for room in result["rooms"]], rooms) - - def _assert_events( - self, result: JsonDict, events: Iterable[Tuple[str, str]] - ) -> None: - """Assert that the expected parent / child room IDs are in the response.""" - self.assertCountEqual( - [ - (event.get("room_id"), event.get("state_key")) - for event in result["events"] - ], - events, - ) - - def test_simple_space(self): - """Test a simple space with a single room.""" - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The result should have the space and the room in it, along with a link - # from space -> room. - self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) - - def test_visibility(self): - """A user not in a space cannot inspect it.""" - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - - # The user cannot see the space. - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - - # If the space is made world-readable it should return a result. - self.helper.send_state( - self.space, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.WORLD_READABLE}, - tok=self.token, - ) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) - - # Make it not world-readable again and confirm it results in an error. - self.helper.send_state( - self.space, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.JOINED}, - tok=self.token, - ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - - # Join the space and results should be returned. - self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) - - def _create_room_with_join_rule( - self, join_rule: str, room_version: Optional[str] = None, **extra_content - ) -> str: - """Create a room with the given join rule and add it to the space.""" - room_id = self.helper.create_room_as( - self.user, - room_version=room_version, - tok=self.token, - extra_content={ - "initial_state": [ - { - "type": EventTypes.JoinRules, - "state_key": "", - "content": { - "join_rule": join_rule, - **extra_content, - }, - } - ] - }, - ) - self._add_child(self.space, room_id, self.token) - return room_id - - def test_filtering(self): - """ - Rooms should be properly filtered to only include rooms the user has access to. 
- """ - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - - # Create a few rooms which will have different properties. - public_room = self._create_room_with_join_rule(JoinRules.PUBLIC) - knock_room = self._create_room_with_join_rule( - JoinRules.KNOCK, room_version=RoomVersions.V7.identifier - ) - not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE) - invited_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.invite(invited_room, targ=user2, tok=self.token) - restricted_room = self._create_room_with_join_rule( - JoinRules.RESTRICTED, - room_version=RoomVersions.V8.identifier, - allow=[], - ) - restricted_accessible_room = self._create_room_with_join_rule( - JoinRules.RESTRICTED, - room_version=RoomVersions.V8.identifier, - allow=[ - { - "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, - "room_id": self.space, - "via": [self.hs.hostname], - } - ], - ) - world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.send_state( - world_readable_room, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.WORLD_READABLE}, - tok=self.token, - ) - joined_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.invite(joined_room, targ=user2, tok=self.token) - self.helper.join(joined_room, user2, tok=token2) - - # Join the space. - self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - - self._assert_rooms( - result, - [ - self.space, - self.room, - public_room, - knock_room, - invited_room, - restricted_accessible_room, - world_readable_room, - joined_room, - ], - ) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, public_room), - (self.space, knock_room), - (self.space, not_invited_room), - (self.space, invited_room), - (self.space, restricted_room), - (self.space, restricted_accessible_room), - (self.space, world_readable_room), - (self.space, joined_room), - ], - ) - - def test_complex_space(self): - """ - Create a "complex" space to see how it handles things like loops and subspaces. - """ - # Create an inaccessible room. - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - room2 = self.helper.create_room_as(user2, is_public=False, tok=token2) - # This is a bit odd as "user" is adding a room they don't know about, but - # it works for the tests. - self._add_child(self.space, room2, self.token) - - # Create a subspace under the space with an additional room in it. - subspace = self.helper.create_room_as( - self.user, - tok=self.token, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - subroom = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, subspace, token=self.token) - self._add_child(subspace, subroom, token=self.token) - # Also add the two rooms from the space into this subspace (causing loops). - self._add_child(subspace, self.room, token=self.token) - self._add_child(subspace, room2, self.token) - - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - - # The result should include each room a single time and each link. 
- self._assert_rooms(result, [self.space, self.room, subspace, subroom]) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, room2), - (self.space, subspace), - (subspace, subroom), - (subspace, self.room), - (subspace, room2), - ], - ) - - def test_fed_complex(self): - """ - Return data over federation and ensure that it is handled properly. - """ - fed_hostname = self.hs.hostname + "2" - subspace = "#subspace:" + fed_hostname - subroom = "#subroom:" + fed_hostname - - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - # Return some good data, and some bad data: - # - # * Event *back* to the root room. - # * Unrelated events / rooms - # * Multiple levels of events (in a not-useful order, e.g. grandchild - # events before child events). - - # Note that these entries are brief, but should contain enough info. - rooms = [ - { - "room_id": subspace, - "world_readable": True, - "room_type": RoomTypes.SPACE, - }, - { - "room_id": subroom, - "world_readable": True, - }, - ] - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": subroom, - "content": event_content, - }, - ] - return rooms, events - - # Add a room to the space which is on another server. - self._add_child(self.space, subspace, self.token) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - - self._assert_rooms(result, [self.space, self.room, subspace, subroom]) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, subspace), - (subspace, subroom), - ], - ) - - def test_fed_filtering(self): - """ - Rooms returned over federation should be properly filtered to only include - rooms the user has access to. - """ - fed_hostname = self.hs.hostname + "2" - subspace = "#subspace:" + fed_hostname - - # Create a few rooms which will have different properties. - public_room = "#public:" + fed_hostname - knock_room = "#knock:" + fed_hostname - not_invited_room = "#not_invited:" + fed_hostname - invited_room = "#invited:" + fed_hostname - restricted_room = "#restricted:" + fed_hostname - restricted_accessible_room = "#restricted_accessible:" + fed_hostname - world_readable_room = "#world_readable:" + fed_hostname - joined_room = self.helper.create_room_as(self.user, tok=self.token) - - # Poke an invite over federation into the database. - fed_handler = self.hs.get_federation_handler() - event = make_event_from_dict( - { - "room_id": invited_room, - "event_id": "!abcd:" + fed_hostname, - "type": EventTypes.Member, - "sender": "@remote:" + fed_hostname, - "state_key": self.user, - "content": {"membership": Membership.INVITE}, - "prev_events": [], - "auth_events": [], - "depth": 1, - "origin_server_ts": 1234, - } - ) - self.get_success( - fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) - ) - - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - # Note that these entries are brief, but should contain enough info. 
- rooms = [ - { - "room_id": public_room, - "world_readable": False, - "join_rules": JoinRules.PUBLIC, - }, - { - "room_id": knock_room, - "world_readable": False, - "join_rules": JoinRules.KNOCK, - }, - { - "room_id": not_invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": restricted_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [], - }, - { - "room_id": restricted_accessible_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [self.room], - }, - { - "room_id": world_readable_room, - "world_readable": True, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": joined_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ] - - # Place each room in the sub-space. - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": room["room_id"], - "content": event_content, - } - for room in rooms - ] - - # Also include the subspace. - rooms.insert( - 0, - { - "room_id": subspace, - "world_readable": True, - }, - ) - return rooms, events - - # Add a room to the space which is on another server. - self._add_child(self.space, subspace, self.token) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - - self._assert_rooms( - result, - [ - self.space, - self.room, - subspace, - public_room, - knock_room, - invited_room, - restricted_accessible_room, - world_readable_room, - joined_room, - ], - ) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, subspace), - (subspace, public_room), - (subspace, knock_room), - (subspace, not_invited_room), - (subspace, invited_room), - (subspace, restricted_room), - (subspace, restricted_accessible_room), - (subspace, world_readable_room), - (subspace, joined_room), - ], - ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index e4059acda356..1ba4c05b9ba0 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -13,7 +13,7 @@ # limitations under the License. 
from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage.databases.main import stats from tests import unittest diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 549876dc85c0..e44bf2b3b187 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -18,8 +18,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import user_directory +from synapse.rest.client import login, room, user_directory from synapse.storage.roommember import ProfileInfo from tests import unittest diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index a37bce08c33a..992d8f94fd70 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging -from typing import Optional -from unittest.mock import Mock +import os +from typing import Iterable, Optional +from unittest.mock import Mock, patch import treq from netaddr import IPSet @@ -22,11 +24,12 @@ from twisted.internet import defer from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions +from twisted.internet.interfaces import IProtocolFactory from twisted.internet.protocol import Factory -from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web._newclient import ResponseNeverReceived from twisted.web.client import Agent -from twisted.web.http import HTTPChannel +from twisted.web.http import HTTPChannel, Request from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS @@ -49,24 +52,6 @@ logger = logging.getLogger(__name__) -test_server_connection_factory = None - - -def get_connection_factory(): - # this needs to happen once, but not until we are ready to run the first test - global test_server_connection_factory - if test_server_connection_factory is None: - test_server_connection_factory = TestServerTLSConnectionFactory( - sanlist=[ - b"DNS:testserv", - b"DNS:target-server", - b"DNS:xn--bcher-kva.com", - b"IP:1.2.3.4", - b"IP:::1", - ] - ) - return test_server_connection_factory - # Once Async Mocks or lambdas are supported this can go away. def generate_resolve_service(result): @@ -100,24 +85,38 @@ def setUp(self): had_well_known_cache=self.had_well_known_cache, ) - self.agent = MatrixFederationAgent( - reactor=self.reactor, - tls_client_options_factory=self.tls_factory, - user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. 
- ip_blacklist=IPSet(), - _srv_resolver=self.mock_resolver, - _well_known_resolver=self.well_known_resolver, - ) - - def _make_connection(self, client_factory, expected_sni): + def _make_connection( + self, + client_factory: IProtocolFactory, + ssl: bool = True, + expected_sni: Optional[bytes] = None, + tls_sanlist: Optional[Iterable[bytes]] = None, + ) -> HTTPChannel: """Builds a test server, and completes the outgoing client connection + Args: + client_factory: the factory that the + application is trying to use to make the outbound connection. We will + invoke it to build the client Protocol + + ssl: If true, we will expect an ssl connection and wrap + server_factory with a TLSMemoryBIOFactory. + Set this to False only when the proxy is expected to speak + plain HTTP; federation requests themselves always use https. + + expected_sni: the expected SNI value + + tls_sanlist: list of SAN entries for the TLS cert presented by the server. Returns: - HTTPChannel: the test server + the server Protocol returned by server_factory """ # build the test server - server_tls_protocol = _build_test_server(get_connection_factory()) + server_factory = _get_test_protocol_factory() + if ssl: + server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) + + server_protocol = server_factory.buildProtocol(None) # now, tell the client protocol factory to build the client protocol (it will be a # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an @@ -128,35 +127,39 @@ def _make_connection(self, client_factory, expected_sni): # stubbing that out here. client_protocol = client_factory.buildProtocol(None) client_protocol.makeConnection( - FakeTransport(server_tls_protocol, self.reactor, client_protocol) + FakeTransport(server_protocol, self.reactor, client_protocol) ) - # tell the server tls protocol to send its stuff back to the client, too - server_tls_protocol.makeConnection( - FakeTransport(client_protocol, self.reactor, server_tls_protocol) + # tell the server protocol to send its stuff back to the client, too + server_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_protocol) ) - # grab a hold of the TLS connection, in case it gets torn down - server_tls_connection = server_tls_protocol._tlsConnection - - # fish the test server back out of the server-side TLS protocol. - http_protocol = server_tls_protocol.wrappedProtocol + if ssl: + # fish the test server back out of the server-side TLS protocol. + http_protocol = server_protocol.wrappedProtocol + # grab a hold of the TLS connection, in case it gets torn down + tls_connection = server_protocol._tlsConnection + else: + http_protocol = server_protocol + tls_connection = None - # give the reactor a pump to get the TLS juices flowing. 
-        self.reactor.pump((0.1,))
+        # give the reactor a pump to get the TLS juices flowing (if needed)
+        self.reactor.advance(0)
 
         # check the SNI
-        server_name = server_tls_connection.get_servername()
-        self.assertEqual(
-            server_name,
-            expected_sni,
-            "Expected SNI %s but got %s" % (expected_sni, server_name),
-        )
+        if expected_sni is not None:
+            server_name = tls_connection.get_servername()
+            self.assertEqual(
+                server_name,
+                expected_sni,
+                f"Expected SNI {expected_sni!s} but got {server_name!s}",
+            )
 
         return http_protocol
 
     @defer.inlineCallbacks
-    def _make_get_request(self, uri):
+    def _make_get_request(self, uri: bytes):
         """
         Sends a simple GET request via the agent, and checks its logcontext management
         """
@@ -180,20 +183,20 @@ def _make_get_request(self, uri):
 
     def _handle_well_known_connection(
         self,
-        client_factory,
-        expected_sni,
-        content,
+        client_factory: IProtocolFactory,
+        expected_sni: bytes,
+        content: bytes,
         response_headers: Optional[dict] = None,
-    ):
+    ) -> HTTPChannel:
         """Handle an outgoing HTTPs connection: wire it up to a server, check that the
         request is for a .well-known, and send the response.
 
         Args:
-            client_factory (IProtocolFactory): outgoing connection
-            expected_sni (bytes): SNI that we expect the outgoing connection to send
-            content (bytes): content to send back as the .well-known
+            client_factory: outgoing connection
+            expected_sni: SNI that we expect the outgoing connection to send
+            content: content to send back as the .well-known
         Returns:
-            HTTPChannel: server impl
+            server impl
         """
         # make the connection for .well-known
         well_known_server = self._make_connection(
@@ -209,7 +212,10 @@ def _handle_well_known_connection(
         return well_known_server
 
     def _send_well_known_response(
-        self, request, content, headers: Optional[dict] = None
+        self,
+        request: Request,
+        content: bytes,
+        headers: Optional[dict] = None,
     ):
         """Check that an incoming request looks like a valid .well-known request, and
         send back the response.
@@ -225,10 +231,37 @@ def _send_well_known_response(
 
         self.reactor.pump((0.1,))
 
-    def test_get(self):
+    def _make_agent(self) -> MatrixFederationAgent:
         """
-        happy-path test of a GET request with an explicit port
+        If a proxy server is set in the environment, the MatrixFederationAgent
+        must be re-created inside the test: the agent built during setUp is
+        created before the proxy settings are patched in.
         """
+        return MatrixFederationAgent(
+            reactor=self.reactor,
+            tls_client_options_factory=self.tls_factory,
+            user_agent="test-agent",  # Note that this is unused since _well_known_resolver is provided.
+            ip_whitelist=IPSet(),
+            ip_blacklist=IPSet(),
+            _srv_resolver=self.mock_resolver,
+            _well_known_resolver=self.well_known_resolver,
+        )
+
+    def test_get(self):
+        """happy-path test of a GET request with an explicit port"""
+        self._do_get()
+
+    @patch.dict(
+        os.environ,
+        {"https_proxy": "proxy.com", "no_proxy": "testserv"},
+    )
+    def test_get_bypass_proxy(self):
+        """test of a GET request with an explicit port, bypassing the proxy"""
+        self._do_get()
+
+    def _do_get(self):
+        """Shared implementation for the GET tests above."""
+        self.agent = self._make_agent()
+
         self.reactor.lookups["testserv"] = "1.2.3.4"
         test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
 
@@ -282,10 +315,188 @@ def test_get(self):
         json = self.successResultOf(treq.json_content(response))
         self.assertEqual(json, {"a": 1})
 
+    @patch.dict(
+        os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"}
+    )
+    def test_get_via_http_proxy(self):
+        """test a federation request through an HTTP proxy"""
+        self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None)
+
+    @patch.dict(
+        os.environ,
+        {"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"},
+    )
+    def test_get_via_http_proxy_with_auth(self):
+        """test a federation request through an HTTP proxy with authentication"""
+        self._do_get_via_proxy(
+            expect_proxy_ssl=False, expected_auth_credentials=b"user:pass"
+        )
+
+    @patch.dict(
+        os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
+    )
+    def test_get_via_https_proxy(self):
+        """test a federation request through an HTTPS proxy"""
+        self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None)
+
+    @patch.dict(
+        os.environ,
+        {"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"},
+    )
+    def test_get_via_https_proxy_with_auth(self):
+        """test a federation request through an HTTPS proxy with authentication"""
+        self._do_get_via_proxy(
+            expect_proxy_ssl=True, expected_auth_credentials=b"user:pass"
+        )
+
+    def _do_get_via_proxy(
+        self,
+        expect_proxy_ssl: bool = False,
+        expected_auth_credentials: Optional[bytes] = None,
+    ):
+        """Send an https federation request via an agent and check that it is
+        correctly received at the proxy and client. The proxy can use either
+        http or https.
+
+        Args:
+            expect_proxy_ssl: True if we expect the request to connect to the proxy via https.
+ expected_auth_credentials: credentials we expect to be presented to authenticate at the proxy + """ + self.agent = self._make_agent() + + self.reactor.lookups["testserv"] = "1.2.3.4" + self.reactor.lookups["proxy.com"] = "9.9.9.9" + test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + # make sure we are connecting to the proxy + self.assertEqual(host, "9.9.9.9") + self.assertEqual(port, 1080) + + # make a test server to act as the proxy, and wire up the client + proxy_server = self._make_connection( + client_factory, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, + ) + + assert isinstance(proxy_server, HTTPChannel) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"testserv:8448") + + # Check whether auth credentials have been supplied to the proxy + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + + if expected_auth_credentials is not None: + # Compute the correct header value for Proxy-Authorization + encoded_credentials = base64.b64encode(expected_auth_credentials) + expected_header_value = b"Basic " + encoded_credentials + + # Validate the header's value + self.assertIn(expected_header_value, proxy_auth_header_values) + else: + # Check that the Proxy-Authorization header has not been supplied to the proxy + self.assertIsNone(proxy_auth_header_values) + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + request.finish() + + # now we make another test server to act as the upstream HTTP server. + server_ssl_protocol = _wrap_server_factory_for_tls( + _get_test_protocol_factory() + ).buildProtocol(None) + + # Tell the HTTP server to send outgoing traffic back via the proxy's transport. + proxy_server_transport = proxy_server.transport + server_ssl_protocol.makeConnection(proxy_server_transport) + + # ... and replace the protocol on the proxy's transport with the + # TLSMemoryBIOProtocol for the test server, so that incoming traffic + # to the proxy gets sent over to the HTTP(s) server. 
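+        # (FakeTransport links two in-memory protocols via its `other`
+        # attribute, so re-pointing `wrappedProtocol` / `other` below is enough
+        # to splice the TLS server into the existing byte stream.)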
+ + # See also comment at `_do_https_request_via_proxy` + # in ../test_proxyagent.py for more details + if expect_proxy_ssl: + assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol) + proxy_server_transport.wrappedProtocol = server_ssl_protocol + else: + assert isinstance(proxy_server_transport, FakeTransport) + client_protocol = proxy_server_transport.other + c2s_transport = client_protocol.transport + c2s_transport.other = server_ssl_protocol + + self.reactor.advance(0) + + server_name = server_ssl_protocol._tlsConnection.get_servername() + expected_sni = b"testserv" + self.assertEqual( + server_name, + expected_sni, + f"Expected SNI {expected_sni!s} but got {server_name!s}", + ) + + # now there should be a pending request + http_server = server_ssl_protocol.wrappedProtocol + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"] + ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] + ) + # Check that the destination server DID NOT receive proxy credentials + self.assertIsNone(request.requestHeaders.getRawHeaders(b"Proxy-Authorization")) + content = request.content.read() + self.assertEqual(content, b"") + + # Deferred is still without a result + self.assertNoResult(test_d) + + # send the headers + request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"]) + request.write("") + + self.reactor.pump((0.1,)) + + response = self.successResultOf(test_d) + + # that should give us a Response object + self.assertEqual(response.code, 200) + + # Send the body + request.write('{ "a": 1 }'.encode("ascii")) + request.finish() + + self.reactor.pump((0.1,)) + + # check it can be read + json = self.successResultOf(treq.json_content(response)) + self.assertEqual(json, {"a": 1}) + def test_get_ip_address(self): """ Test the behaviour when the server name contains an explicit IP (with no port) """ + self.agent = self._make_agent() + # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.4"] = "1.2.3.4" @@ -320,6 +531,7 @@ def test_get_ipv6_address(self): Test the behaviour when the server name contains an explicit IPv6 address (with no port) """ + self.agent = self._make_agent() # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" @@ -355,6 +567,7 @@ def test_get_ipv6_address_with_port(self): Test the behaviour when the server name contains an explicit IPv6 address (with explicit port) """ + self.agent = self._make_agent() # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" @@ -389,6 +602,8 @@ def test_get_hostname_bad_cert(self): """ Test the behaviour when the certificate on the server doesn't match the hostname """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv1"] = "1.2.3.4" @@ -441,6 +656,8 @@ def test_get_ip_address_bad_cert(self): Test the behaviour when the server name contains an explicit IP, but the server cert doesn't cover it """ + self.agent = self._make_agent() + # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.5"] = "1.2.3.5" @@ -471,6 +688,7 @@ def test_get_no_srv_no_well_known(self): """ Test the behaviour when the server name has no port, no SRV, and no well-known """ + self.agent = self._make_agent() 
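+        # (the agent is built inside the test rather than in setUp so that any
+        # proxy environment variables patched by the test are already in effect)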
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -524,6 +742,7 @@ def test_get_no_srv_no_well_known(self): def test_get_well_known(self): """Test the behaviour when the .well-known delegates elsewhere""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -587,6 +806,8 @@ def test_get_well_known_redirect(self): """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -675,6 +896,7 @@ def test_get_invalid_well_known(self): """ Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -743,6 +965,7 @@ def test_get_well_known_unsigned_cert(self): reactor=self.reactor, tls_client_options_factory=tls_factory, user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. + ip_whitelist=IPSet(), ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( @@ -780,6 +1003,8 @@ def test_get_hostname_srv(self): """ Test the behaviour when there is a single SRV record """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [Server(host=b"srvtarget", port=8443)] ) @@ -820,6 +1045,8 @@ def test_get_well_known_srv(self): """Test the behaviour when the .well-known redirects to a place where there is a SRV. 
""" + self.agent = self._make_agent() + self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["srvtarget"] = "5.6.7.8" @@ -876,6 +1103,7 @@ def test_get_well_known_srv(self): def test_idna_servername(self): """test the behaviour when the server name has idna chars in""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) @@ -937,6 +1165,7 @@ def test_idna_servername(self): def test_idna_srv_target(self): """test the behaviour when the target of a SRV record has idna chars""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com @@ -1140,6 +1369,8 @@ def test_well_known_too_large(self): def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails.""" + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [ Server(host=b"target.com", port=8443), @@ -1266,34 +1497,49 @@ def _check_logcontext(context): raise AssertionError("Expected logcontext %s but was %s" % (context, current)) -def _build_test_server(connection_creator): - """Construct a test server - - This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol - +def _wrap_server_factory_for_tls( + factory: IProtocolFactory, sanlist: Iterable[bytes] = None +) -> IProtocolFactory: + """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory + The resultant factory will create a TLS server which presents a certificate + signed by our test CA, valid for the domains in `sanlist` Args: - connection_creator (IOpenSSLServerConnectionCreator): thing to build - SSL connections - sanlist (list[bytes]): list of the SAN entries for the cert returned - by the server + factory: protocol factory to wrap + sanlist: list of domains the cert should be valid for + Returns: + interfaces.IProtocolFactory + """ + if sanlist is None: + sanlist = [ + b"DNS:testserv", + b"DNS:target-server", + b"DNS:xn--bcher-kva.com", + b"IP:1.2.3.4", + b"IP:::1", + ] + + connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) + return TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=factory + ) + +def _get_test_protocol_factory() -> IProtocolFactory: + """Get a protocol Factory which will build an HTTPChannel Returns: - TLSMemoryBIOProtocol + interfaces.IProtocolFactory """ server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. 
server_factory.log = _log_request - server_tls_factory = TLSMemoryBIOFactory( - connection_creator, isClient=False, wrappedFactory=server_factory - ) - - return server_tls_factory.buildProtocol(None) + return server_factory -def _log_request(request): +def _log_request(request: str): """Implements Factory.log, which is expected by Request.finish""" - logger.info("Completed request %s", request) + logger.info(f"Completed request {request}") @implementer(IPolicyForHTTPS) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index e5865c161d5e..2db77c6a7345 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -29,7 +29,8 @@ from twisted.web.http import HTTPChannel from synapse.http.client import BlacklistingReactorWrapper -from synapse.http.proxyagent import ProxyAgent, ProxyCredentials, parse_proxy +from synapse.http.connectproxyclient import ProxyCredentials +from synapse.http.proxyagent import ProxyAgent, parse_proxy from tests.http import TestServerTLSConnectionFactory, get_test_https_policy from tests.server import FakeTransport, ThreadedMemoryReactorClock @@ -392,7 +393,9 @@ def test_http_request_via_proxy(self): """ Tests that requests can be made through a proxy. """ - self._do_http_request_via_proxy(ssl=False, auth_credentials=None) + self._do_http_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -402,13 +405,17 @@ def test_http_request_via_proxy_with_auth(self): """ Tests that authenticated requests can be made through a proxy. """ - self._do_http_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") + self._do_http_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict( os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"} ) def test_http_request_via_https_proxy(self): - self._do_http_request_via_proxy(ssl=True, auth_credentials=None) + self._do_http_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -418,12 +425,16 @@ def test_http_request_via_https_proxy(self): }, ) def test_http_request_via_https_proxy_with_auth(self): - self._do_http_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") + self._do_http_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) def test_https_request_via_proxy(self): """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=False, auth_credentials=None) + self._do_https_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -431,14 +442,18 @@ def test_https_request_via_proxy(self): ) def test_https_request_via_proxy_with_auth(self): """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") + self._do_https_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict( os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} ) def test_https_request_via_https_proxy(self): """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=True, auth_credentials=None) + self._do_https_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=None + ) @patch.dict( 
os.environ, @@ -446,20 +461,22 @@ def test_https_request_via_https_proxy(self): ) def test_https_request_via_https_proxy_with_auth(self): """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") + self._do_https_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" + ) def _do_http_request_via_proxy( self, - ssl: bool = False, - auth_credentials: Optional[bytes] = None, + expect_proxy_ssl: bool = False, + expected_auth_credentials: Optional[bytes] = None, ): """Send a http request via an agent and check that it is correctly received at the proxy. The proxy can use either http or https. Args: - ssl: True if we expect the request to connect via https to proxy - auth_credentials: credentials to authenticate at proxy + expect_proxy_ssl: True if we expect the request to connect via https to proxy + expected_auth_credentials: credentials to authenticate at proxy """ - if ssl: + if expect_proxy_ssl: agent = ProxyAgent( self.reactor, use_proxy=True, contextFactory=get_test_https_policy() ) @@ -480,9 +497,9 @@ def _do_http_request_via_proxy( http_server = self._make_connection( client_factory, _get_test_protocol_factory(), - ssl=ssl, - tls_sanlist=[b"DNS:proxy.com"] if ssl else None, - expected_sni=b"proxy.com" if ssl else None, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, ) # the FakeTransport is async, so we need to pump the reactor @@ -498,9 +515,9 @@ def _do_http_request_via_proxy( b"Proxy-Authorization" ) - if auth_credentials is not None: + if expected_auth_credentials is not None: # Compute the correct header value for Proxy-Authorization - encoded_credentials = base64.b64encode(auth_credentials) + encoded_credentials = base64.b64encode(expected_auth_credentials) expected_header_value = b"Basic " + encoded_credentials # Validate the header's value @@ -523,14 +540,14 @@ def _do_http_request_via_proxy( def _do_https_request_via_proxy( self, - ssl: bool = False, - auth_credentials: Optional[bytes] = None, + expect_proxy_ssl: bool = False, + expected_auth_credentials: Optional[bytes] = None, ): """Send a https request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. 
Args: - ssl: True if we expect the request to connect via https to proxy - auth_credentials: credentials to authenticate at proxy + expect_proxy_ssl: True if we expect the request to connect via https to proxy + expected_auth_credentials: credentials to authenticate at proxy """ agent = ProxyAgent( self.reactor, @@ -552,9 +569,9 @@ def _do_https_request_via_proxy( proxy_server = self._make_connection( client_factory, _get_test_protocol_factory(), - ssl=ssl, - tls_sanlist=[b"DNS:proxy.com"] if ssl else None, - expected_sni=b"proxy.com" if ssl else None, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, ) assert isinstance(proxy_server, HTTPChannel) @@ -570,9 +587,9 @@ def _do_https_request_via_proxy( b"Proxy-Authorization" ) - if auth_credentials is not None: + if expected_auth_credentials is not None: # Compute the correct header value for Proxy-Authorization - encoded_credentials = base64.b64encode(auth_credentials) + encoded_credentials = base64.b64encode(expected_auth_credentials) expected_header_value = b"Basic " + encoded_credentials # Validate the header's value @@ -606,7 +623,7 @@ def _do_https_request_via_proxy( # Protocol to implement the proxy, which starts out by forwarding to an # HTTPChannel (to implement the CONNECT command) and can then be switched # into a mode where it forwards its traffic to another Protocol.) - if ssl: + if expect_proxy_ssl: assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol) proxy_server_transport.wrappedProtocol = server_ssl_protocol else: diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 81d9e2f48474..7dd519cd44a4 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -20,7 +20,7 @@ from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState from synapse.rest import admin -from synapse.rest.client.v1 import login, presence, room +from synapse.rest.client import login, presence, room from synapse.types import create_requester from tests.events.test_presence_router import send_presence_update, sync_presence @@ -79,6 +79,16 @@ def test_can_register_user(self): displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + def test_get_userinfo_by_id(self): + user_id = self.register_user("alice", "1234") + found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + self.assertEqual(found_user.user_id.to_string(), user_id) + self.assertIdentical(found_user.is_admin, False) + + def test_get_userinfo_by_id__no_user_found(self): + found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test")) + self.assertIsNone(found_user) + def test_sending_events_into_room(self): """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent diff --git a/tests/push/test_email.py b/tests/push/test_email.py index e04bc5c9a661..e0a3342088d4 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -21,7 +21,7 @@ import synapse.rest.admin from synapse.api.errors import Codes, SynapseError -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.unittest import HomeserverTestCase @@ -45,14 +45,6 @@ class EmailPusherTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): - # List[Tuple[Deferred, args, kwargs]] - self.email_attempts = [] - - 
def sendmail(*args, **kwargs): - d = Deferred() - self.email_attempts.append((d, args, kwargs)) - return d - config = self.default_config() config["email"] = { "enable_notifs": True, @@ -75,7 +67,17 @@ def sendmail(*args, **kwargs): config["public_baseurl"] = "aaa" config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + # List[Tuple[Deferred, args, kwargs]] + self.email_attempts = [] + + def sendmail(*args, **kwargs): + d = Deferred() + self.email_attempts.append((d, args, kwargs)) + return d + + hs.get_send_email_handler()._sendmail = sendmail return hs diff --git a/tests/push/test_http.py b/tests/push/test_http.py index ffd75b14914f..c068d329a98b 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -18,8 +18,7 @@ import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfigException -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import receipts +from synapse.rest.client import login, receipts, room from tests.unittest import HomeserverTestCase, override_config diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 666008425a58..f198a9488746 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -24,7 +24,7 @@ EventsStreamRow, ) from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.replication._base import BaseStreamTestCase from tests.test_utils.event_injection import inject_event, inject_member_event diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index 1346e0e160a4..43a16bb141db 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from synapse.rest.client.v2_alpha import register +from synapse.rest.client import register from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import FakeChannel, make_request diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index b9751efdc53b..995097d72ccc 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -13,7 +13,7 @@ # limitations under the License. 
import logging -from synapse.rest.client.v2_alpha import register +from synapse.rest.client import register from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index a0c710f85568..af5dfca752b9 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -17,7 +17,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests.replication._base import BaseMultiWorkerStreamTestCase diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index ffa425328f0d..ac419f0db37d 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -22,7 +22,7 @@ from twisted.web.server import Request from synapse.rest import admin -from synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.server import HomeServer from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 1e4e3821b9df..4094a75f363c 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.replication._base import BaseMultiWorkerStreamTestCase diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index f3615af97e86..0a6e4795ee92 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -16,8 +16,7 @@ from synapse.api.room_versions import RoomVersion from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index a7c6e595b983..bfa638fb4b55 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -24,8 +24,7 @@ from synapse.http.server import JsonResource from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import groups +from synapse.rest.client import groups, login, room from tests import unittest from tests.server import FakeSite, make_request diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 120730b76417..c4afe5c3d90b 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -17,7 +17,7 @@ import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest diff --git a/tests/rest/admin/test_event_reports.py 
b/tests/rest/admin/test_event_reports.py index f15d1cf6f7c8..e9ef89731ffe 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -16,8 +16,7 @@ import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import report_event +from synapse.rest.client import login, report_event, room from tests import unittest diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 7198fd293f52..972d60570c6c 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -20,7 +20,7 @@ import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, profile, room +from synapse.rest.client import login, profile, room from synapse.rest.media.v1.filepath import MediaFilePaths from tests import unittest diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 17ec8bfd3b93..c9d4731017c1 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -22,7 +22,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes -from synapse.rest.client.v1 import directory, events, login, room +from synapse.rest.client import directory, events, login, room from tests import unittest diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 79cac4266bf1..5cd82209c4a2 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -18,7 +18,7 @@ import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 42f50c092101..ef7727523870 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -15,17 +15,20 @@ import hashlib import hmac import json +import os import urllib.parse from binascii import unhexlify from typing import List, Optional from unittest.mock import Mock, patch +from parameterized import parameterized + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions -from synapse.rest.client.v1 import login, logout, profile, room -from synapse.rest.client.v2_alpha import devices, sync +from synapse.rest.client import devices, login, logout, profile, room, sync +from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.types import JsonDict, UserID from tests import unittest @@ -72,7 +75,7 @@ def test_disabled(self): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "Shared secret registration is not enabled", channel.json_body["error"] ) @@ -104,7 +107,7 @@ def test_expired_nonce(self): body = json.dumps({"nonce": nonce}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds @@ -112,7 +115,7 @@ def test_expired_nonce(self): channel = self.make_request("POST", self.url, 
body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self): @@ -166,7 +169,7 @@ def test_register_correct_nonce(self): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) def test_nonce_reuse(self): @@ -191,13 +194,13 @@ def test_nonce_reuse(self): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self): @@ -219,7 +222,7 @@ def nonce(): body = json.dumps({}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) # @@ -230,28 +233,28 @@ def nonce(): body = json.dumps({"nonce": nonce()}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # @@ -262,28 +265,28 @@ def nonce(): body = json.dumps({"nonce": nonce(), "username": "a"}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) channel = 
self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # @@ -301,7 +304,7 @@ def nonce(): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) def test_displayname(self): @@ -322,11 +325,11 @@ def test_displayname(self): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob1:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None @@ -348,11 +351,11 @@ def test_displayname(self): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob2:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty @@ -374,7 +377,7 @@ def test_displayname(self): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") @@ -399,11 +402,11 @@ def test_displayname(self): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob4:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("Bob's Name", 
channel.json_body["displayname"]) @override_config( @@ -449,7 +452,7 @@ def test_register_mau_limit_reached(self): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -638,7 +641,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order @@ -1085,7 +1088,7 @@ def test_deactivate_user_erase_false(self): content={"erase": False}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1236,56 +1239,114 @@ def test_user_does_not_exist(self): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) - def test_get_user(self): + def test_invalid_parameter(self): """ - Test a simple get of a user. + If parameters are invalid, an error is returned. """ + + # admin not bool channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"admin": "not_bool"}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual("User", channel.json_body["displayname"]) - self._check_fields(channel.json_body) + # deactivated not bool + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": "not_bool"}, + ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - def test_get_user_with_sso(self): - """ - Test get a user with SSO details. 
- """ - self.get_success( - self.store.record_user_external_id( - "auth_provider1", "external_id1", self.other_user - ) + # password not str + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"password": True}, ) - self.get_success( - self.store.record_user_external_id( - "auth_provider2", "external_id2", self.other_user - ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # password not length + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"password": "x" * 513}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + # user_type not valid channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"user_type": "new type"}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual( - "external_id1", channel.json_body["external_ids"][0]["external_id"] + # external_ids not valid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"} + }, ) - self.assertEqual( - "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"external_ids": {"external_id": "id"}}, ) - self.assertEqual( - "external_id2", channel.json_body["external_ids"][1]["external_id"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # threepids not valid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": {"medium": "email", "wrong_address": "id"}}, ) - self.assertEqual( - "auth_provider2", channel.json_body["external_ids"][1]["auth_provider"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": {"address": "value"}}, + ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_get_user(self): + """ + Test a simple get of a user. 
+ """ + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) def test_create_server_admin(self): @@ -1349,6 +1410,12 @@ def test_create_user(self): "admin": False, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider1", + }, + ], "avatar_url": "mxc://fibble/wibble", } @@ -1364,6 +1431,12 @@ def test_create_user(self): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + ) self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self._check_fields(channel.json_body) @@ -1628,6 +1701,103 @@ def test_set_threepid(self): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + def test_set_external_id(self): + """ + Test setting external id for an other user. + """ + + # Add two external_ids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider1", + }, + { + "external_id": "external_id2", + "auth_provider": "auth_provider2", + }, + ] + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) + # result does not always have the same sort order, therefore it becomes sorted + self.assertEqual( + sorted(channel.json_body["external_ids"], key=lambda k: k["auth_provider"]), + [ + {"auth_provider": "auth_provider1", "external_id": "external_id1"}, + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + ], + ) + self._check_fields(channel.json_body) + + # Set a new and remove an external_id + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id2", + "auth_provider": "auth_provider2", + }, + { + "external_id": "external_id3", + "auth_provider": "auth_provider3", + }, + ] + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) + self.assertEqual( + channel.json_body["external_ids"], + [ + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + {"auth_provider": "auth_provider3", "external_id": "external_id3"}, + ], + ) + self._check_fields(channel.json_body) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual( + channel.json_body["external_ids"], + [ + 
{"auth_provider": "auth_provider2", "external_id": "external_id2"}, + {"auth_provider": "auth_provider3", "external_id": "external_id3"}, + ], + ) + self._check_fields(channel.json_body) + + # Remove external_ids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"external_ids": []}, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(0, len(channel.json_body["external_ids"])) + def test_deactivate_user(self): """ Test deactivating another user. @@ -2180,7 +2350,7 @@ def test_user_is_not_local(self): ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_get_pushers(self): """ @@ -2249,6 +2419,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() + self.filepaths = MediaFilePaths(hs.config.media_store_path) self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2258,37 +2429,34 @@ def prepare(self, reactor, clock, hs): self.other_user ) - def test_no_auth(self): - """ - Try to list media of an user without authentication. - """ - channel = self.make_request("GET", self.url, b"{}") + @parameterized.expand(["GET", "DELETE"]) + def test_no_auth(self, method: str): + """Try to list media of an user without authentication.""" + channel = self.make_request(method, self.url, {}) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): - """ - If the user is not a server admin, an error is returned. 
- """ + @parameterized.expand(["GET", "DELETE"]) + def test_requester_is_no_admin(self, method: str): + """If the user is not a server admin, an error is returned.""" other_user_token = self.login("user", "pass") channel = self.make_request( - "GET", + method, self.url, access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): - """ - Tests that a lookup for a user that does not exist returns a 404 - """ + @parameterized.expand(["GET", "DELETE"]) + def test_user_does_not_exist(self, method: str): + """Tests that a lookup for a user that does not exist returns a 404""" url = "/_synapse/admin/v1/users/@unknown_person:test/media" channel = self.make_request( - "GET", + method, url, access_token=self.admin_user_tok, ) @@ -2296,25 +2464,22 @@ def test_user_does_not_exist(self): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): - """ - Tests that a lookup for a user that is not a local returns a 400 - """ + @parameterized.expand(["GET", "DELETE"]) + def test_user_is_not_local(self, method: str): + """Tests that a lookup for a user that is not a local returns a 400""" url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" channel = self.make_request( - "GET", + method, url, access_token=self.admin_user_tok, ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual("Can only look up local users", channel.json_body["error"]) - def test_limit(self): - """ - Testing list of media with limit - """ + def test_limit_GET(self): + """Testing list of media with limit""" number_media = 20 other_user_tok = self.login("user", "pass") @@ -2326,16 +2491,31 @@ def test_limit(self): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 5) self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["media"]) - def test_from(self): - """ - Testing list of media with a defined starting point (from) - """ + def test_limit_DELETE(self): + """Testing delete of media with limit""" + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media_for_user(other_user_tok, number_media) + + channel = self.make_request( + "DELETE", + self.url + "?limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 5) + self.assertEqual(len(channel.json_body["deleted_media"]), 5) + + def test_from_GET(self): + """Testing list of media with a defined starting point (from)""" number_media = 20 other_user_tok = self.login("user", "pass") @@ -2347,16 +2527,31 @@ def test_from(self): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 15) self.assertNotIn("next_token", channel.json_body) 
self._check_fields(channel.json_body["media"]) - def test_limit_and_from(self): - """ - Testing list of media with a defined starting point and limit - """ + def test_from_DELETE(self): + """Testing delete of media with a defined starting point (from)""" + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media_for_user(other_user_tok, number_media) + + channel = self.make_request( + "DELETE", + self.url + "?from=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 15) + self.assertEqual(len(channel.json_body["deleted_media"]), 15) + + def test_limit_and_from_GET(self): + """Testing list of media with a defined starting point and limit""" number_media = 20 other_user_tok = self.login("user", "pass") @@ -2368,59 +2563,78 @@ def test_limit_and_from(self): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["media"]), 10) self._check_fields(channel.json_body["media"]) - def test_invalid_parameter(self): - """ - If parameters are invalid, an error is returned. - """ + def test_limit_and_from_DELETE(self): + """Testing delete of media with a defined starting point and limit""" + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media_for_user(other_user_tok, number_media) + + channel = self.make_request( + "DELETE", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 10) + self.assertEqual(len(channel.json_body["deleted_media"]), 10) + + @parameterized.expand(["GET", "DELETE"]) + def test_invalid_parameter(self, method: str): + """If parameters are invalid, an error is returned.""" # unkown order_by channel = self.make_request( - "GET", + method, self.url + "?order_by=bar", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order channel = self.make_request( - "GET", + method, self.url + "?dir=bar", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # negative limit channel = self.make_request( - "GET", + method, self.url + "?limit=-5", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from channel = self.make_request( - "GET", + method, self.url + "?from=-5", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self): """ Testing that `next_token` appears at the right place + + For deleting media `next_token` is not useful, because + 
the remaining media are reordered after each deletion. """ number_media = 20 @@ -2435,7 +2649,7 @@ def test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2448,7 +2662,7 @@ def test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2461,7 +2675,7 @@ def test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -2475,12 +2689,12 @@ def test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 1) self.assertNotIn("next_token", channel.json_body) - def test_user_has_no_media(self): + def test_user_has_no_media_GET(self): """ Tests that a normal lookup for media is successful if user has no media created @@ -2496,11 +2710,24 @@ def test_user_has_no_media(self): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["media"])) - def test_get_media(self): + def test_user_has_no_media_DELETE(self): """ - Tests that a normal lookup for media is successfully + Tests that a delete is successful if user has no media """ + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["deleted_media"])) + + def test_get_media(self): + """Tests that a normal lookup for media is successful""" + number_media = 5 other_user_tok = self.login("user", "pass") self._create_media_for_user(other_user_tok, number_media) @@ -2517,6 +2744,35 @@ def test_get_media(self): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["media"]) + def test_delete_media(self): + """Tests that a normal delete of media is successful""" + + number_media = 5 + other_user_tok = self.login("user", "pass") + media_ids = self._create_media_for_user(other_user_tok, number_media) + + # Test if the file exists + local_paths = [] + for media_id in media_ids: + local_path = self.filepaths.local_media_filepath(media_id) + self.assertTrue(os.path.exists(local_path)) + local_paths.append(local_path) + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_media, channel.json_body["total"]) + self.assertEqual(number_media, len(channel.json_body["deleted_media"])) +
self.assertCountEqual(channel.json_body["deleted_media"], media_ids) + + # Test if the file is deleted + for local_path in local_paths: + self.assertFalse(os.path.exists(local_path)) + def test_order_by(self): """ Testing order list with parameter `order_by` @@ -2622,13 +2878,16 @@ def test_order_by(self): [media2] + sorted([media1, media3]), "safe_from_quarantine", "b" ) - def _create_media_for_user(self, user_token: str, number_media: int): + def _create_media_for_user(self, user_token: str, number_media: int) -> List[str]: """ Create a number of media for a specific user Args: user_token: Access token of the user number_media: Number of media to be created for the user + Returns: + List of created media IDs """ + media_ids = [] for _ in range(number_media): # file size is 67 bytes image_data = unhexlify( @@ -2637,7 +2896,9 @@ def _create_media_for_user(self, user_token: str, number_media: int): b"0a2db40000000049454e44ae426082" ) - self._create_media_and_access(user_token, image_data) + media_ids.append(self._create_media_and_access(user_token, image_data)) + + return media_ids def _create_media_and_access( self, @@ -2680,7 +2941,7 @@ def _create_media_and_access( 200, channel.code, msg=( - "Expected to receive a 200 on accessing media: %s" % server_and_media_id + f"Expected to receive a 200 on accessing media: {server_and_media_id}" ), ) @@ -2718,12 +2979,12 @@ def _order_test( url = self.url + "?" if order_by is not None: - url += "order_by=%s&" % (order_by,) + url += f"order_by={order_by}&" if dir is not None and dir in ("b", "f"): - url += "dir=%s" % (dir,) + url += f"dir={dir}" channel = self.make_request( "GET", - url.encode("ascii"), + url, access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -2762,7 +3023,7 @@ def _get_token(self) -> str: channel = self.make_request( "POST", self.url, b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) return channel.json_body["access_token"] def test_no_auth(self): @@ -2803,7 +3064,7 @@ def test_devices(self): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # We should only see the one device (from the login in `prepare`) self.assertEqual(len(channel.json_body["devices"]), 1) @@ -2815,11 +3076,11 @@ def test_logout(self): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout with the puppet token channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) @@ -2829,7 +3090,7 @@ def test_logout(self): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_user_logout_all(self): """Tests that the target user calling
`/logout/all` does *not* expire @@ -2840,17 +3101,17 @@ def test_user_logout_all(self): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the real user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should still work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # .. but the real user's tokens shouldn't channel = self.make_request( @@ -2867,13 +3128,13 @@ def test_admin_logout_all(self): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the admin user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) @@ -2883,7 +3144,7 @@ def test_admin_logout_all(self): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) @unittest.override_config( { @@ -3243,7 +3504,7 @@ def test_user_is_not_local(self): ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual("Can only look up local users", channel.json_body["error"]) channel = self.make_request( "POST", @@ -3279,7 +3540,7 @@ def test_invalid_parameter(self): content={"messages_per_second": "string"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # messages_per_second is negative @@ -3290,7 +3551,7 @@ def test_invalid_parameter(self): content={"messages_per_second": -1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is a string @@ -3301,7 +3562,7 @@ def test_invalid_parameter(self): content={"burst_count": "string"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is negative @@ -3312,7 +3573,7 @@ def test_invalid_parameter(self): content={"burst_count": -1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) 
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_return_zero_when_null(self): @@ -3337,7 +3598,7 @@ def test_return_zero_when_null(self): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["burst_count"]) @@ -3351,7 +3612,7 @@ def test_success(self): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3362,7 +3623,7 @@ def test_success(self): access_token=self.admin_user_tok, content={"messages_per_second": 10, "burst_count": 11}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(11, channel.json_body["burst_count"]) @@ -3373,7 +3634,7 @@ def test_success(self): access_token=self.admin_user_tok, content={"messages_per_second": 20, "burst_count": 21}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3383,7 +3644,7 @@ def test_success(self): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3393,7 +3654,7 @@ def test_success(self): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3403,6 +3664,6 @@ def test_success(self): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py new file mode 100644 index 000000000000..4e1c49c28b8d --- /dev/null +++ b/tests/rest/admin/test_username_available.py @@ -0,0 +1,62 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
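An aside on the override_ratelimit tests just above, before the body of the new test file continues: the endpoint follows a plain GET/POST/DELETE lifecycle. The sketch below drives that lifecycle from a standalone script; it is an illustration only, not part of this patch — the homeserver address, user ID, and access token are placeholders, and the endpoint path is an assumption based on the Synapse admin API rather than anything shown in this diff (the tests reference it only via `self.url`).

```python
# Hypothetical standalone sketch of the ratelimit-override lifecycle the
# tests above exercise. BASE, the user ID, the token, and the endpoint path
# are all assumptions/placeholders, not taken from this patch.
import requests

BASE = "http://localhost:8008"  # placeholder homeserver address
URL = f"{BASE}/_synapse/admin/v1/users/@user:example.com/override_ratelimit"
HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # placeholder token

# Set an override, read it back, then remove it again.
requests.post(URL, json={"messages_per_second": 10, "burst_count": 11}, headers=HEADERS)
print(requests.get(URL, headers=HEADERS).json())  # expect the values just set
requests.delete(URL, headers=HEADERS)
print(requests.get(URL, headers=HEADERS).json())  # expect an empty object again
```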
+ +import synapse.rest.admin +from synapse.api.errors import Codes, SynapseError +from synapse.rest.client import login + +from tests import unittest + + +class UsernameAvailableTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + url = "/_synapse/admin/v1/username_available" + + def prepare(self, reactor, clock, hs): + self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + async def check_username(username): + if username == "allowed": + return True + raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + + handler = self.hs.get_registration_handler() + handler.check_username = check_username + + def test_username_available(self): + """ + The endpoint should return a 200 response if the username does not exist + """ + + url = f"{self.url}?username=allowed" + channel = self.make_request("GET", url, None, self.admin_user_tok) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertTrue(channel.json_body["available"]) + + def test_username_unavailable(self): + """ + The endpoint should return a 400 response if the username is already in use + """ + + url = f"{self.url}?username=disallowed" + channel = self.make_request("GET", url, None, self.admin_user_tok) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE") + self.assertEqual(channel.json_body["error"], "User ID already taken.") diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 5cc62a910a43..65c58ce70a84 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -16,7 +16,7 @@ import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.rest.consent import consent_resource from tests import unittest diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index eec0fc01f938..3d7aa8ec8683 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.api.constants import EventContentFields, EventTypes from synapse.rest import admin -from synapse.rest.client.v1 import room +from synapse.rest.client import room from tests import unittest diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 478296ba0efa..ca2e8ff8ef01 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -15,7 +15,7 @@ import json import synapse.rest.admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index ba5ad47df5a0..91d0762cb0ab 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -13,8 +13,7 @@ # limitations under the License.
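The new `username_available` file above defines the endpoint URL directly, so a client-side usage sketch is straightforward. Hedged illustration only — the server address and token are placeholders; the path and the 200/400 behaviour come from the test file itself:

```python
# Minimal client sketch for the username_available admin endpoint added
# above; only the URL path and response shape are taken from the patch.
import requests

resp = requests.get(
    "http://localhost:8008/_synapse/admin/v1/username_available",  # placeholder host
    params={"username": "alice"},
    headers={"Authorization": "Bearer <admin_access_token>"},  # placeholder token
)
# Expect 200 with {"available": true} when the name is free, or 400 with
# errcode M_USER_IN_USE when it is already taken.
print(resp.status_code, resp.json())
```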
from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index dfd85221d01c..433d715f695d 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -13,8 +13,7 @@ # limitations under the License. from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index e1a6e73e17be..b58452195a82 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -15,7 +15,7 @@ from synapse.api.constants import EventTypes from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.visibility import filter_events_for_client from tests import unittest diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 288ee128886b..6a0d9a82be93 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -16,8 +16,13 @@ import synapse.rest.admin from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.rest.client import ( + directory, + login, + profile, + room, + room_upgrade_rest_servlet, +) from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 28dd47a28bf4..0ae40296403f 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -19,7 +19,7 @@ from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.module_api import ModuleApi from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import Requester, StateMap from synapse.util.frozenutils import unfreeze diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index 8ed470490b4a..d2181ea9070f 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -15,7 +15,7 @@ import json from synapse.rest import admin -from synapse.rest.client.v1 import directory, login, room +from synapse.rest.client import directory, login, room from synapse.types import RoomAlias from synapse.util.stringutils import random_string diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 2789d5154660..a90294003eac 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -17,7 +17,7 @@ from unittest.mock import Mock import synapse.rest.admin -from synapse.rest.client.v1 import events, login, room +from synapse.rest.client import events, login, room from tests import unittest diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 7eba69642a6b..eba3552b19ac 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -24,9 +24,8 @@ import 
synapse.rest.admin from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login, logout -from synapse.rest.client.v2_alpha import devices, register -from synapse.rest.client.v2_alpha.account import WhoamiRestServlet +from synapse.rest.client import devices, login, logout, register +from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import create_requester diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 597e4c67de4f..1d152352d176 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.handlers.presence import PresenceHandler -from synapse.rest.client.v1 import presence +from synapse.rest.client import presence from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 165ad33fb740..2860579c2e54 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,7 +14,7 @@ """Tests REST events for /profile paths.""" from synapse.rest import admin -from synapse.rest.client.v1 import login, profile, room +from synapse.rest.client import login, profile, room from tests import unittest diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py index d0776160824c..d0ce91ccd95c 100644 --- a/tests/rest/client/v1/test_push_rule_attrs.py +++ b/tests/rest/client/v1/test_push_rule_attrs.py @@ -13,7 +13,7 @@ # limitations under the License. import synapse from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, push_rule, room +from synapse.rest.client import login, push_rule, room from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 3df070c93653..0c9cbb9aff52 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -19,15 +19,17 @@ import json from typing import Iterable -from unittest.mock import Mock +from unittest.mock import Mock, call from urllib import parse as urlparse +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.errors import HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin -from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import account +from synapse.rest.client import account, directory, login, profile, room from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string @@ -1124,6 +1126,93 @@ def test_restricted_auth(self): self.assertEqual(channel.code, 200, channel.result) +class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): + """Test that we correctly fall back to local filtering if a remote server + doesn't support search.
+ """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(federation_client=Mock()) + + def prepare(self, reactor, clock, hs): + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.federation_client = hs.get_federation_client() + + def test_simple(self): + "Simple test for searching rooms over federation" + self.federation_client.get_public_rooms.side_effect = ( + lambda *a, **k: defer.succeed({}) + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_called_once_with( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ) + + def test_fallback(self): + "Test that searching public rooms over federation falls back if it gets a 404" + + # The `get_public_rooms` should be called again if the first call fails + # with a 404, when using search filters. + self.federation_client.get_public_rooms.side_effect = ( + HttpResponseException(404, "Not Found", b""), + defer.succeed({}), + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_has_calls( + [ + call( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ), + call( + "testserv", + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ), + ] + ) + + class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 44e22ca999cf..b54b00473327 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -17,7 +17,7 @@ from unittest.mock import Mock -from synapse.rest.client.v1 import room +from synapse.rest.client import room from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index fc2d35596ea3..954ad1a1fdef 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -47,10 +47,10 @@ class RestHelper: def create_room_as( self, - room_creator: str = None, + room_creator: Optional[str] = None, is_public: bool = True, - room_version: str = None, - tok: str = None, + room_version: Optional[str] = None, + tok: Optional[str] = None, expect_code: int = 200, extra_content: Optional[Dict] = None, custom_headers: Optional[ diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 317a2287e376..b946fca8b367 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -25,8 +25,7 @@ from synapse.api.constants import LoginType, Membership from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService -from 
synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import account, register +from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from tests import unittest @@ -47,12 +46,6 @@ def make_homeserver(self, reactor, clock): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - return - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -67,7 +60,16 @@ async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): } config["public_baseurl"] = "https://example.com" - hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + hs.get_send_email_handler()._sendmail = sendmail + return hs def prepare(self, reactor, clock, hs): @@ -511,11 +513,6 @@ def make_homeserver(self, reactor, clock): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -530,7 +527,16 @@ async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): } config["public_baseurl"] = "https://example.com" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + return self.hs def prepare(self, reactor, clock, hs): diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 6b90f838b6da..cf5cfb910c8c 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -19,8 +19,7 @@ import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import account, auth, devices, register +from synapse.rest.client import account, auth, devices, login, register from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import JsonDict, UserID diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index f80f48a45577..13b3c5f499b7 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -13,8 +13,7 @@ # limitations under the License. 
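The `test_account.py` hunks above replace the old `sendmail` keyword argument to `setup_test_homeserver` with a stub assigned directly onto the send-email handler. Distilled into a minimal sketch — assuming Synapse's own test harness, with an illustrative class name:

```python
# Sketch of the email-capture pattern introduced by the refactor above;
# assumes tests.unittest.HomeserverTestCase from this repository.
from tests import unittest


class EmailCaptureExample(unittest.HomeserverTestCase):
    def make_homeserver(self, reactor, clock):
        hs = self.setup_test_homeserver()

        async def sendmail(
            reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
        ):
            # Record outgoing mail instead of talking to a real SMTP server.
            self.email_attempts.append(msg)

        self.email_attempts = []
        hs.get_send_email_handler()._sendmail = sendmail
        return hs
```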
import synapse.rest.admin from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import capabilities +from synapse.rest.client import capabilities, login from tests import unittest from tests.unittest import override_config @@ -103,7 +102,8 @@ def test_get_change_password_capabilities_password_disabled(self): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) - def test_get_does_not_include_msc3244_fields_by_default(self): + @override_config({"experimental_features": {"msc3244_enabled": False}}) + def test_get_does_not_include_msc3244_fields_when_disabled(self): localpart = "user" password = "pass" user = self.register_user(localpart, password) @@ -121,7 +121,6 @@ def test_get_does_not_include_msc3244_fields_by_default(self): "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] ) - @override_config({"experimental_features": {"msc3244_enabled": True}}) def test_get_does_include_msc3244_fields_when_enabled(self): localpart = "user" password = "pass" diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index c7e47725b789..475c6bed3d1a 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -15,7 +15,7 @@ from twisted.internet import defer from synapse.api.errors import Codes -from synapse.rest.client.v2_alpha import filter +from synapse.rest.client import filter from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py index 6f07ff6cbbca..3cf5871899e1 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -17,8 +17,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.rest import admin -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import account, password_policy, register +from synapse.rest.client import account, login, password_policy, register from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 1cad5f00eb20..fecda037a54b 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -23,8 +23,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login, logout -from synapse.rest.client.v2_alpha import account, account_validity, register, sync +from synapse.rest.client import account, account_validity, login, logout, register, sync from tests import unittest from tests.unittest import override_config @@ -509,10 +508,6 @@ def make_homeserver(self, reactor, clock): } # Email config. 
- self.email_attempts = [] - - async def sendmail(*args, **kwargs): - self.email_attempts.append((args, kwargs)) config["email"] = { "enable_notifs": True, @@ -532,7 +527,13 @@ async def sendmail(*args, **kwargs): } config["public_baseurl"] = "aaa" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail(*args, **kwargs): + self.email_attempts.append((args, kwargs)) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail self.store = self.hs.get_datastore() diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 2e2f94742ef8..02b5e9a8d0d4 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -19,8 +19,7 @@ from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import register, relations +from synapse.rest.client import login, register, relations, room from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py index a76a6fef1e3f..ee6b0b9ebfb0 100644 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ b/tests/rest/client/v2_alpha/test_report_event.py @@ -15,8 +15,7 @@ import json import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import report_event +from synapse.rest.client import login, report_event, room from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/v2_alpha/test_sendtodevice.py index c9c99cc5d7a8..6db7062a8e1f 100644 --- a/tests/rest/client/v2_alpha/test_sendtodevice.py +++ b/tests/rest/client/v2_alpha/test_sendtodevice.py @@ -13,8 +13,7 @@ # limitations under the License. from synapse.rest import admin -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import sendtodevice, sync +from synapse.rest.client import login, sendtodevice, sync from tests.unittest import HomeserverTestCase, override_config diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py index cedb9614a8ae..283eccd53f95 100644 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
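Looking back at the media tests earlier in this patch, `@parameterized.expand(["GET", "DELETE"])` is what lets one test body cover both HTTP methods. A self-contained illustration of the idiom, with made-up class and test names:

```python
# Standalone illustration of the parameterized.expand idiom used by the
# media tests above; not part of this patch.
from unittest import TestCase

from parameterized import parameterized


class MethodParamExample(TestCase):
    @parameterized.expand(["GET", "DELETE"])
    def test_method_is_known(self, method: str):
        # parameterized generates one test per entry, e.g.
        # test_method_is_known_0_GET and test_method_is_known_1_DELETE.
        self.assertIn(method, ("GET", "DELETE"))
```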
import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import shared_rooms +from synapse.rest.client import login, room, shared_rooms from tests import unittest from tests.server import FakeChannel diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 15748ed4fd9f..95be369d4be1 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -21,8 +21,7 @@ ReadReceiptEventFields, RelationTypes, ) -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import knock, read_marker, receipts, sync +from synapse.rest.client import knock, login, read_marker, receipts, room, sync from tests import unittest from tests.federation.transport.test_knocking import ( diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py index 5f3f15fc57cd..72f976d8e2ed 100644 --- a/tests/rest/client/v2_alpha/test_upgrade_room.py +++ b/tests/rest/client/v2_alpha/test_upgrade_room.py @@ -15,8 +15,7 @@ from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.rest.client import login, room, room_upgrade_rest_servlet from tests import unittest from tests.server import FakeChannel diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 2d6b49692ee7..6085444b9da8 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -30,7 +30,7 @@ from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.logging.context import make_deferred_yieldable from synapse.rest import admin -from synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index ac98259b7ee6..58b399a04377 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -15,8 +15,7 @@ import os import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests import unittest diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 3245aa91ca6e..8701b5f7e340 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -19,8 +19,7 @@ from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 932970fd9ad1..a649e8c61872 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ 
-14,7 +14,10 @@ import json from synapse.logging.context import LoggingContext +from synapse.rest import admin +from synapse.rest.client import login, room from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util.async_helpers import yieldable_gather_results from tests import unittest @@ -94,3 +97,50 @@ def test_query_via_event_cache(self): res = self.get_success(self.store.have_seen_events("room1", ["event10"])) self.assertEquals(res, {"event10"}) self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + + +class EventCacheTestCase(unittest.HomeserverTestCase): + """Test that the various layers of the event cache work.""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store: EventsWorkerStore = hs.get_datastore() + + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + + self.room = self.helper.create_room_as(self.user, tok=self.token) + + res = self.helper.send(self.room, tok=self.token) + self.event_id = res["event_id"] + + # Reset the event cache so the tests start with it empty + self.store._get_event_cache.clear() + + def test_simple(self): + """Test that we cache events that we pull from the DB.""" + + with LoggingContext("test") as ctx: + self.get_success(self.store.get_event(self.event_id)) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + + def test_dedupe(self): + """Test that if we request the same event multiple times we only pull it + out once. + """ + + with LoggingContext("test") as ctx: + d = yieldable_gather_results( + self.store.get_event, [self.event_id, self.event_id] + ) + self.get_success(d) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 77c4fe721c1d..da98733ce8e6 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -17,7 +17,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage import prepare_database from synapse.types import UserID, create_requester diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index e57fce9694bf..1c2df54ecc53 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -17,7 +17,7 @@ import synapse.rest.admin from synapse.http.site import XForwardedForRequest -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest from tests.server import make_request diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index d87f124c2638..93136f071793 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -20,7 +20,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage.databases.main.events import _LinkMap from synapse.types import create_requester diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 617bc8091fa8..f462a8b1c721 100644 --- a/tests/storage/test_events.py +++
b/tests/storage/test_events.py @@ -17,7 +17,7 @@ from synapse.api.room_versions import RoomVersions from synapse.federation.federation_base import event_from_pdu_json from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.unittest import HomeserverTestCase diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index e5574063f17f..22a77c3cccc5 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.api.errors import NotFoundError, SynapseError -from synapse.rest.client.v1 import room +from synapse.rest.client import room from tests.unittest import HomeserverTestCase diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 9fa968f6bb30..c72dc40510a4 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -15,7 +15,7 @@ from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests import unittest diff --git a/tests/test_federation.py b/tests/test_federation.py index 0ed8326f55b8..3785799f46d2 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -75,10 +75,8 @@ def setUp(self): ) self.handler = self.homeserver.get_federation_handler() - self.handler._check_event_auth = ( - lambda origin, event, context, state, auth_events, backfilled: succeed( - context - ) + self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed( + context ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( diff --git a/tests/test_mau.py b/tests/test_mau.py index fa6ef92b3bd8..66111eb3674b 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -17,7 +17,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.appservice import ApplicationService -from synapse.rest.client.v2_alpha import register, sync +from synapse.rest.client import register, sync from tests import unittest from tests.unittest import override_config diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 0df480db9f17..67dcf567cdb8 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -17,7 +17,7 @@ from twisted.test.proto_helpers import MemoryReactorClock -from synapse.rest.client.v2_alpha.register import register_servlets +from synapse.rest.client.register import register_servlets from synapse.util import Clock from tests import unittest diff --git a/tox.ini b/tox.ini index da77d124fc0e..5a62ec76c23f 100644 --- a/tox.ini +++ b/tox.ini @@ -49,7 +49,7 @@ lint_targets = contrib synctl synmark - .buildkite + .ci docker # default settings for all tox environments
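One last note, on the public-rooms fallback test in `test_rooms.py` above: it relies on `Mock.side_effect` accepting an iterable, where each call consumes the next element and an element that is an exception instance is raised instead of returned. A standalone demonstration — names are illustrative, and the real test uses `HttpResponseException` and a deferred rather than the plain values here:

```python
# Demonstrates the Mock.side_effect behaviour the fallback test relies on;
# standalone example, not part of this patch.
from unittest.mock import Mock

get_public_rooms = Mock(side_effect=[RuntimeError("404 from remote"), {"chunk": []}])

try:
    get_public_rooms("first call")
except RuntimeError:
    # The first element was an exception instance, so it was raised.
    pass

# The second element is returned as the call's result.
assert get_public_rooms("second call") == {"chunk": []}
```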